Coverage for oc_meta / run / meta / preprocess_input.py: 38%

173 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-03 17:25 +0000

1#!/usr/bin/python 

2# -*- coding: utf-8 -*- 

3# Copyright (c) 2025 Arcangelo <arcangelo.massari@unibo.it> 

4# 

5# Permission to use, copy, modify, and/or distribute this software for any purpose 

6# with or without fee is hereby granted, provided that the above copyright notice 

7# and this permission notice appear in all copies. 

8# 

9# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH 

10# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND 

11# FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, 

12# OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, 

13# DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 

14# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS 

15# SOFTWARE. 

16 

from __future__ import annotations

import argparse
import csv
import os
import sys
from typing import List, Union

import redis
from rich_argparse import RichHelpFormatter
from sparqlite import SPARQLClient
from tqdm import tqdm

28 

29 

class ProcessingStats(object):
    """Mutable counters accumulated while filtering input CSV rows."""

    def __init__(self):
        # All counters start at zero; callers increment them as rows are read.
        self.total_rows = 0         # every row read from the input files
        self.duplicate_rows = 0     # rows discarded as exact duplicates
        self.existing_ids_rows = 0  # rows discarded because all IDs already exist
        self.processed_rows = 0     # rows kept and written to output

37 

def create_redis_connection(host: str, port: int, db: int = 10) -> redis.Redis:
    """Build a Redis client for the given host/port/db.

    decode_responses=True makes the client return ``str`` values instead of
    raw ``bytes``, which the ID-existence checks rely on.
    """
    return redis.Redis(host=host, port=port, db=db, decode_responses=True)

46 

def check_ids_existence_redis(ids: str, redis_client: redis.Redis) -> bool:
    """
    Check if all IDs in the input string exist in Redis.

    Args:
        ids: Space-separated identifier strings used directly as Redis keys
        redis_client: Connected Redis client (decode_responses assumed)

    Returns:
        True if every ID resolves to a truthy value, False otherwise
        (including empty input).
    """
    if not ids:
        return False

    id_list = ids.split()

    # MGET fetches every key in a single round trip instead of one GET per
    # identifier.  A missing key comes back as None, so `all(...)` matches
    # the previous per-key `if not redis_client.get(id_str)` truthiness test.
    return all(redis_client.mget(id_list))

62 

def check_ids_existence_sparql(ids: str, sparql_endpoint: str) -> bool:
    """
    Check if all IDs in the input string exist in the SPARQL endpoint.

    Args:
        ids: Space-separated identifiers in "scheme:value" form (e.g. "doi:10.1/x")
        sparql_endpoint: URL of the SPARQL endpoint to query

    Returns:
        True only if every identifier is found; False for empty input, a
        malformed identifier (no ':'), a missing identifier, or a query error.
    """
    if not ids:
        return False

    id_list = ids.split()

    with SPARQLClient(sparql_endpoint, max_retries=5, backoff_factor=5, timeout=3600) as client:
        for id_str in id_list:
            # A well-formed identifier is "scheme:value"; without a colon it
            # cannot exist in the triplestore.  (The previous code raised
            # IndexError on parts[1] for such input.)
            if ":" not in id_str:
                return False

            # Escape backslashes FIRST, then quotes: a literal backslash in
            # the value could otherwise neutralise the quote escaping inside
            # the SPARQL string literal.
            escaped_id = id_str.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')

            parts = escaped_id.split(":", 1)
            scheme = parts[0]
            value = parts[1]

            query = f"""
            PREFIX datacite: <http://purl.org/spar/datacite/>
            PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
            PREFIX literal: <http://www.essepuntato.it/2010/06/literalreification/>

            ASK {{
                ?identifier datacite:usesIdentifierScheme datacite:{scheme} ;
                    literal:hasLiteralValue ?value .
                FILTER(
                    ?value = "{value}" ||
                    ?value = "{value}"^^xsd:string
                )
            }}
            """

            try:
                results = client.query(query)
                # An ASK query answers with a {'boolean': true/false} document.
                if not results.get('boolean', False):
                    return False
            except Exception:
                # Treat any transport/parse failure as "not found" so a single
                # bad ID does not abort the whole preprocessing run.
                return False

    return True

104 

def check_ids_existence(ids: str, storage_type: str, storage_reference: Union[redis.Redis, str, None]) -> bool:
    """
    Check if all IDs in the input string exist in the storage.

    Args:
        ids: String of space-separated IDs to check
        storage_type: Either 'redis', 'sparql', or None to skip checking
        storage_reference: Redis client, SPARQL endpoint URL, or None

    Returns:
        True if all IDs exist, False otherwise, or False if storage_type is None

    Raises:
        ValueError: if storage_type is any other value
    """
    # Guard-clause dispatch: each recognised storage type returns immediately.
    if storage_type is None:
        return False
    if storage_type == 'redis':
        assert isinstance(storage_reference, redis.Redis)
        return check_ids_existence_redis(ids, storage_reference)
    if storage_type == 'sparql':
        assert isinstance(storage_reference, str)
        return check_ids_existence_sparql(ids, storage_reference)
    raise ValueError(f"Invalid storage type: {storage_type}. Must be 'redis', 'sparql', or None")

127 

def get_csv_files(directory: str) -> List[str]:
    """Return the paths of all CSV files directly inside *directory*.

    Only the first level is scanned (no recursion); entries must both end
    in '.csv' and be regular files.

    Raises:
        ValueError: if *directory* is not an existing directory.
    """
    if not os.path.isdir(directory):
        raise ValueError("The specified path '{}' is not a directory".format(directory))

    csv_paths = []
    for entry in os.listdir(directory):
        full_path = os.path.join(directory, entry)
        if entry.endswith('.csv') and os.path.isfile(full_path):
            csv_paths.append(full_path)
    return csv_paths

138 

def process_csv_file(input_file, output_dir, current_file_num, rows_per_file=3000,
                     storage_type='redis', storage_reference=None, redis_db=10,
                     redis_host='localhost', redis_port=6379, seen_rows=None, pending_rows=None):
    """
    Process a single CSV file and write non-duplicate rows with non-existing IDs to output files.

    Args:
        input_file: Path to the input CSV file
        output_dir: Directory where output files will be written
        current_file_num: Number to use for the next output file
        rows_per_file: Number of rows per output file
        storage_type: Type of storage to check IDs against ('redis', 'sparql', or None to skip)
        storage_reference: Redis client or SPARQL endpoint URL. If None and storage_type is 'redis',
                           a new connection will be created
        redis_db: Redis database number to use if storage_type is 'redis' and storage_reference is None
        redis_host: Redis host if storage_type is 'redis' and storage_reference is None
        redis_port: Redis port if storage_type is 'redis' and storage_reference is None
        seen_rows: Set of previously seen rows (for cross-file deduplication)
        pending_rows: List of rows waiting to be written (for cross-file batching)

    Returns:
        Tuple of (next file number, processing statistics, pending rows)
    """
    rows_to_write = pending_rows if pending_rows is not None else []
    file_num = current_file_num
    seen_rows = seen_rows if seen_rows is not None else set()

    if storage_type == 'redis':
        storage_ref = storage_reference if storage_reference is not None else create_redis_connection(redis_host, redis_port, redis_db)
    else:
        storage_ref = storage_reference

    stats = ProcessingStats()

    # Raise the CSV field size limit once, up front.  The previous
    # implementation caught "field larger than field limit" mid-file, doubled
    # the limit and re-read the whole file: rows from the aborted pass were
    # then counted a second time in stats.total_rows and misclassified as
    # duplicates (they were already in seen_rows), corrupting the report —
    # and any batch already flushed to disk made the retry unrecoverable.
    try:
        csv.field_size_limit(sys.maxsize)
    except OverflowError:
        # Some platforms reject sys.maxsize (C long overflow); fall back to
        # the largest value accepted everywhere.
        csv.field_size_limit(2 ** 31 - 1)

    with open(input_file, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        if fieldnames is None:
            # Empty file: nothing to read.  (The old `assert` here would be
            # stripped under -O and crash later instead.)
            return file_num, stats, rows_to_write

        for row in reader:
            stats.total_rows += 1
            # frozenset of items is order-insensitive and hashable, so it
            # deduplicates identical rows across all input files.
            row_hash = frozenset(row.items())

            if row_hash in seen_rows:
                stats.duplicate_rows += 1
                continue

            seen_rows.add(row_hash)

            if check_ids_existence(row['id'], storage_type, storage_ref):
                stats.existing_ids_rows += 1
                continue

            stats.processed_rows += 1
            rows_to_write.append(row)

            if len(rows_to_write) >= rows_per_file:
                # NOTE(review): the batch header comes from the *current*
                # file's fieldnames; pending rows carried over from a file
                # with different columns would make DictWriter raise —
                # confirm all inputs share one schema.
                output_file = os.path.join(output_dir, "{}.csv".format(file_num))
                with open(output_file, 'w', encoding='utf-8', newline='') as out_f:
                    writer = csv.DictWriter(out_f, fieldnames=fieldnames)
                    writer.writeheader()
                    writer.writerows(rows_to_write)
                file_num += 1
                rows_to_write = []

    return file_num, stats, rows_to_write

214 

def print_processing_report(all_stats: List[ProcessingStats], input_files: List[str], storage_type: str) -> None:
    """Print a detailed report of the processing statistics."""
    # Aggregate the per-file counters into one ProcessingStats.
    total = ProcessingStats()
    total.total_rows = sum(s.total_rows for s in all_stats)
    total.duplicate_rows = sum(s.duplicate_rows for s in all_stats)
    total.existing_ids_rows = sum(s.existing_ids_rows for s in all_stats)
    total.processed_rows = sum(s.processed_rows for s in all_stats)

    print("\nProcessing Report:")
    print("=" * 50)
    if storage_type:
        print(f"Storage type used: {storage_type.upper()}")
    else:
        print("Storage type used: None (ID checking skipped)")
    print(f"Total input files processed: {len(input_files)}")
    print(f"Total input rows: {total.total_rows}")
    print(f"Rows discarded (duplicates): {total.duplicate_rows}")
    if storage_type:
        print(f"Rows discarded (existing IDs): {total.existing_ids_rows}")
    print(f"Rows written to output: {total.processed_rows}")

    # Percentages only make sense when at least one row was read.
    if total.total_rows > 0:
        duplicate_percent = (total.duplicate_rows / total.total_rows) * 100
        processed_percent = (total.processed_rows / total.total_rows) * 100

        print("\nPercentages:")
        print(f"Duplicate rows: {duplicate_percent:.1f}%")
        if storage_type:
            existing_percent = (total.existing_ids_rows / total.total_rows) * 100
            print(f"Existing IDs: {existing_percent:.1f}%")
        print(f"Processed rows: {processed_percent:.1f}%")

247 

def main():
    """CLI entry point: deduplicate input CSVs and optionally drop rows whose
    IDs already exist in Redis or at a SPARQL endpoint.

    Returns:
        0 on success, 1 on any error (missing --sparql-endpoint, no input
        CSV files, or a processing exception).
    """
    parser = argparse.ArgumentParser(
        description="Process CSV files and optionally check IDs against a storage system (Redis or SPARQL)",
        formatter_class=RichHelpFormatter,
    )
    parser.add_argument(
        "input_dir",
        help="Directory containing input CSV files"
    )
    parser.add_argument(
        "output_dir",
        help="Directory for output CSV files"
    )
    parser.add_argument(
        "--rows-per-file",
        type=int,
        default=3000,
        help="Number of rows per output file (default: 3000)"
    )
    parser.add_argument(
        "--storage-type",
        choices=["redis", "sparql"],
        help="Storage type to check IDs against (redis or sparql). If not specified, ID checking is skipped"
    )
    parser.add_argument(
        "--redis-host",
        default="localhost",
        help="Redis host (default: localhost)"
    )
    parser.add_argument(
        "--redis-port",
        type=int,
        default=6379,
        help="Redis port (default: 6379)"
    )
    parser.add_argument(
        "--redis-db",
        type=int,
        default=10,
        help="Redis database number to use if storage type is redis (default: 10)"
    )
    parser.add_argument(
        "--sparql-endpoint",
        help="SPARQL endpoint URL if storage type is sparql"
    )
    args = parser.parse_args()

    # SPARQL checking is impossible without an endpoint URL; fail early.
    if args.storage_type == "sparql" and not args.sparql_endpoint:
        print("Error: --sparql-endpoint is required when --storage-type is sparql")
        return 1

    os.makedirs(args.output_dir, exist_ok=True)

    try:
        csv_files = get_csv_files(args.input_dir)
        if not csv_files:
            print("No CSV files found in directory: {}".format(args.input_dir))
            return 1

        print("Found {} CSV files to process".format(len(csv_files)))

        storage_reference = None
        storage_type = args.storage_type

        if storage_type:
            print("Using {} for ID existence checking".format(storage_type.upper()))
            if storage_type == "redis":
                # One shared connection, reused across every input file.
                storage_reference = create_redis_connection(args.redis_host, args.redis_port, args.redis_db)
            else:
                storage_reference = args.sparql_endpoint
        else:
            print("Skipping ID existence checking")

        current_file_num = 0
        all_stats = []
        # seen_rows and pending_rows are threaded through every call so that
        # deduplication and output batching work across file boundaries.
        seen_rows = set()
        pending_rows = []

        for csv_file in tqdm(csv_files, desc="Processing CSV files"):
            current_file_num, stats, pending_rows = process_csv_file(
                csv_file,
                args.output_dir,
                current_file_num,
                rows_per_file=args.rows_per_file,
                storage_type=storage_type,
                storage_reference=storage_reference,
                redis_db=args.redis_db,
                redis_host=args.redis_host,
                redis_port=args.redis_port,
                seen_rows=seen_rows,
                pending_rows=pending_rows
            )
            all_stats.append(stats)

        # Flush the final partial batch (fewer than rows_per_file rows).
        if pending_rows:
            output_file = os.path.join(args.output_dir, "{}.csv".format(current_file_num))
            with open(output_file, 'w', encoding='utf-8', newline='') as out_f:
                # NOTE(review): header comes from the first pending row's keys —
                # assumes every input file shares the same columns; confirm.
                writer = csv.DictWriter(out_f, fieldnames=pending_rows[0].keys())
                writer.writeheader()
                writer.writerows(pending_rows)

        print_processing_report(all_stats, csv_files, storage_type)

    except Exception as e:
        # Top-level boundary: report the failure and signal it via the
        # return code instead of letting the traceback escape.
        print("Error: {}".format(str(e)))
        return 1

    return 0

356 

if __name__ == "__main__":
    # Propagate main()'s status code (1 on error, 0 on success) to the shell;
    # the previous bare `main()` call discarded it and always exited 0.
    raise SystemExit(main())