Coverage for oc_ds_converter/datasource/orcid_index.py: 80%

148 statements

# SPDX-FileCopyrightText: 2026 Arcangelo Massari <arcangelo.massari@unibo.it>
#
# SPDX-License-Identifier: ISC

from __future__ import annotations

import json
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from csv import DictReader
from os import cpu_count, sep, walk
from os.path import exists
from typing import Protocol, cast

import fakeredis

from oc_ds_converter.datasource.redis import RedisDataSource
from oc_ds_converter.lib.console import create_progress
from oc_ds_converter.oc_idmanager import DOIManager


class OrcidIndexInterface(Protocol):
    def get_value(self, id_string: str, /) -> set[str] | None: ...
    def get_values_batch(self, ids: list[str], /) -> dict[str, set[str]]: ...

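Because this is a `Protocol`, any class implementing these two methods satisfies the interface structurally, with no inheritance required; `OrcidIndexRedis` below is one such implementation. As a minimal sketch (the helper name `count_indexed_dois` is hypothetical, not part of the module), a consumer can be typed against the interface rather than the concrete Redis class:

def count_indexed_dois(index: OrcidIndexInterface, dois: list[str]) -> int:
    # Only DOIs with at least one ORCID appear in the batch result,
    # so the mapping's length is the count of indexed DOIs.
    return len(index.get_values_batch(dois))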

class OrcidIndexRedis:
    def __init__(self, testing: bool = False) -> None:
        if testing:
            self._r = fakeredis.FakeStrictRedis(decode_responses=True)
        else:
            self._redis = RedisDataSource("DOI-ORCID-INDEX")
            self._r = self._redis._r

    def get_value(self, doi: str) -> set[str] | None:
        result = cast(set[str], self._r.smembers(doi))
        if result:
            return result
        return None

    def get_values_batch(self, dois: list[str]) -> dict[str, set[str]]:
        if not dois:
            return {}
        pipe = self._r.pipeline()
        for doi in dois:
            pipe.smembers(doi)
        results = pipe.execute()
        return {
            doi: cast(set[str], members)
            for doi, members in zip(dois, results)
            if members
        }

    def add_values_batch(self, data: dict[str, set[str]]) -> None:
        pipe = self._r.pipeline()
        for doi, values in data.items():
            if values:
                pipe.sadd(doi, *values)
        pipe.execute()

    def has_data(self) -> bool:
        return cast(int, self._r.dbsize()) > 0

    def clear(self) -> None:
        self._r.flushdb()

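A minimal round-trip sketch, using `testing=True` so everything runs against an in-memory fakeredis instance. The DOI and ORCID strings are invented, and the `doi:`-prefixed key format assumes the form that `DOIManager.normalise(..., include_prefix=True)` produces in the loader below:

index = OrcidIndexRedis(testing=True)
index.add_values_batch({"doi:10.1000/example": {"0000-0001-2345-6789"}})
assert index.get_value("doi:10.1000/example") == {"0000-0001-2345-6789"}
assert index.get_value("doi:10.1000/missing") is None
# Batch lookup silently drops DOIs with no members, per get_values_batch.
assert index.get_values_batch(["doi:10.1000/example", "doi:10.1000/missing"]) == {
    "doi:10.1000/example": {"0000-0001-2345-6789"}
}
index.clear()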

def _process_csv_file(csv_path: str) -> dict[str, set[str]]:
    """Process a single CSV file and return DOI -> ORCID mappings.

    This function runs in a separate process for parallelization.
    """
    doi_manager = DOIManager()
    result: dict[str, set[str]] = defaultdict(set)

    with open(csv_path, 'r', encoding='utf-8') as f:
        reader = DictReader(f)
        for row in reader:
            raw_doi = row['id']
            doi = doi_manager.normalise(raw_doi, include_prefix=True)
            if doi:
                result[doi].add(row['value'])

    return dict(result)

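Each input CSV is expected to expose at least `id` and `value` columns, where `id` holds a raw DOI and `value` an ORCID. A hypothetical two-row file, assuming `DOIManager.normalise` strips the resolver URL and lowercases, as DOI normalisation conventionally does:

id,value
https://doi.org/10.1000/EXAMPLE,0000-0001-2345-6789
10.1000/example,0000-0002-9876-5432

Under that assumption, both rows normalise to the same key, so the returned mapping would hold one DOI with a two-element ORCID set.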

def load_orcid_index_to_redis(
    orcid_index_dir: str,
    orcid_index_redis: OrcidIndexRedis,
    batch_size: int = 1000000,
    max_workers: int | None = None,
) -> None:
    if not exists(orcid_index_dir):
        return

    files_to_process: list[str] = []
    for cur_dir, _, cur_files in walk(orcid_index_dir):
        for cur_file in cur_files:
            if cur_file.endswith('.csv'):
                files_to_process.append(cur_dir + sep + cur_file)

    if not files_to_process:
        return

    if max_workers is None:
        max_workers = min(cpu_count() or 4, len(files_to_process))

    batch: dict[str, set[str]] = {}
    count = 0

    with create_progress() as progress:
        task = progress.add_task(
            "[green]Loading DOI-ORCID index files", total=len(files_to_process)
        )

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(_process_csv_file, path): path
                for path in files_to_process
            }

            for future in as_completed(futures):
                file_data = future.result()

                # Merge results into batch
                for doi, values in file_data.items():
                    if doi in batch:
                        batch[doi].update(values)
                    else:
                        batch[doi] = values
                    count += len(values)

                # Flush to Redis when batch is large enough
                if count >= batch_size:
                    orcid_index_redis.add_values_batch(batch)
                    batch = {}
                    count = 0

                progress.update(task, advance=1)

    if batch:
        orcid_index_redis.add_values_batch(batch)

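A guarded one-off load might look like the sketch below; the directory path is illustrative, and `testing=True` is used here only so the sketch has no external dependency (a real load would use the default Redis-backed store). The `__main__` guard matters: `ProcessPoolExecutor` re-imports this module in worker processes on spawn-based platforms, so an unguarded module-level call would re-execute on import.

if __name__ == "__main__":
    index = OrcidIndexRedis(testing=True)
    # Skip the load if a previous run already populated the database.
    if not index.has_data():
        load_orcid_index_to_redis("/data/doi_orcid_index", index, batch_size=500_000)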

class PublishersRedis:
    MEMBER_PREFIX = "member:"
    DOI_PREFIX_KEY = "prefix:"

    def __init__(self, testing: bool = False) -> None:
        if testing:
            self._r = fakeredis.FakeStrictRedis(decode_responses=True)
        else:
            self._redis = RedisDataSource("PUBLISHERS-INDEX")
            self._r = self._redis._r

    def get_by_member(self, member_id: str) -> dict[str, str | set[str]] | None:
        key = f"{self.MEMBER_PREFIX}{member_id}"
        data = self._r.get(key)
        if data:
            result = json.loads(str(data))
            result["prefixes"] = set(result["prefixes"])
            return result
        return None

    def get_by_prefix(self, prefix: str) -> dict[str, str | set[str]] | None:
        key = f"{self.DOI_PREFIX_KEY}{prefix}"
        member_id = self._r.get(key)
        if member_id:
            return self.get_by_member(str(member_id))
        return None

    def set_publisher(self, member_id: str, name: str, prefixes: set[str]) -> None:
        member_key = f"{self.MEMBER_PREFIX}{member_id}"
        data = {"name": name, "prefixes": list(prefixes)}
        self._r.set(member_key, json.dumps(data))
        for prefix in prefixes:
            prefix_key = f"{self.DOI_PREFIX_KEY}{prefix}"
            self._r.set(prefix_key, member_id)

    def set_publishers_batch(self, publishers: dict[str, dict[str, str | set[str]]]) -> None:
        pipe = self._r.pipeline()
        for member_id, data in publishers.items():
            member_key = f"{self.MEMBER_PREFIX}{member_id}"
            prefixes_list = list(data["prefixes"])
            pipe.set(member_key, json.dumps({"name": data["name"], "prefixes": prefixes_list}))
            for prefix in prefixes_list:
                prefix_key = f"{self.DOI_PREFIX_KEY}{prefix}"
                pipe.set(prefix_key, member_id)
        pipe.execute()

    def has_data(self) -> bool:
        return cast(int, self._r.dbsize()) > 0

    def clear(self) -> None:
        self._r.flushdb()

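A round-trip sketch for the publishers index, again against fakeredis; the member ID, name, and prefixes are invented:

publishers = PublishersRedis(testing=True)
publishers.set_publisher("78", "Example Press", {"10.1000", "10.1001"})
record = publishers.get_by_prefix("10.1000")
assert record is not None
assert record["name"] == "Example Press"
assert record["prefixes"] == {"10.1000", "10.1001"}

Note the dual-key layout: each `member:<id>` key stores the JSON record, while each `prefix:<doi-prefix>` key points back at the member ID, so a lookup by prefix costs two GETs.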

def load_publishers_to_redis(
    publishers_filepath: str,
    publishers_redis: PublishersRedis,
    batch_size: int = 5000,
) -> None:
    if not exists(publishers_filepath):
        return

    batch: dict[str, dict[str, str | set[str]]] = {}
    count = 0

    with open(publishers_filepath, 'r', encoding='utf-8') as f:
        reader = DictReader(f)
        for row in reader:
            pub_id = row['id']
            if pub_id not in batch:
                batch[pub_id] = {'name': row['name'], 'prefixes': set()}
            prefixes = batch[pub_id]['prefixes']
            prefixes.add(row['prefix'])  # type: ignore[union-attr]
            count += 1

            if count >= batch_size:
                publishers_redis.set_publishers_batch(batch)
                batch = {}
                count = 0

    if batch:
        publishers_redis.set_publishers_batch(batch)

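The publishers CSV is read with `id`, `name`, and `prefix` columns, one row per (publisher, prefix) pair. A hypothetical file and load call:

# publishers.csv (hypothetical):
#   id,name,prefix
#   78,Example Press,10.1000
#   78,Example Press,10.1001
publishers = PublishersRedis(testing=True)
load_publishers_to_redis("publishers.csv", publishers)

Rows sharing an `id` are merged in memory only within a batch; since each flush resets the batch and `set_publishers_batch` overwrites member keys, a publisher whose rows straddled a flush boundary would keep only the later prefixes, so input files are expected to keep each publisher's rows contiguous.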