Coverage for oc_meta/run/meta/preprocess_input.py: 100%

161 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-21 09:24 +0000

1#!/usr/bin/python 

2 

3# SPDX-FileCopyrightText: 2025-2026 Arcangelo Massari <arcangelo.massari@unibo.it> 

4# 

5# SPDX-License-Identifier: ISC 

6 

7from __future__ import annotations 

8 

9import argparse 

10import multiprocessing 

11import os 

12from concurrent.futures import ProcessPoolExecutor, as_completed 

13from dataclasses import dataclass 

14from typing import Callable, List 

15 

16import redis 

17from rich.table import Table 

18from rich_argparse import RichHelpFormatter 

19 

20from oc_meta.constants import QLEVER_BATCH_SIZE, QLEVER_MAX_WORKERS 

21from oc_meta.lib.console import console, create_progress 

22from oc_meta.lib.file_manager import get_csv_data, write_csv 

23from oc_meta.lib.sparql import run_queries_parallel 

24from oc_meta.run.meta.merge_csv import resolve_output_path 

25 

# Common URI prefix of datacite identifier-scheme resources; stripped from
# SPARQL results to recover the bare scheme name.
DATACITE_PREFIX = "http://purl.org/spar/datacite/"

27 

28 

@dataclass
class ProcessingStats:
    """Counters accumulated while filtering and deduplicating CSV rows."""

    # Rows read from the input CSV file(s).
    total_rows: int = 0
    # Rows dropped because an identical row was already seen.
    duplicate_rows: int = 0
    # Rows dropped because every identifier in the row already exists
    # in the Redis database or the SPARQL endpoint.
    existing_ids_rows: int = 0
    # Rows that survived filtering/deduplication and were written out.
    processed_rows: int = 0

35 

36 

@dataclass
class FileResult:
    """Surviving rows of one input CSV file, keyed for deduplication."""

    # Path of the input CSV file this result was produced from.
    file_path: str
    # (row signature, row) pairs; the signature is the row's items sorted
    # into a tuple, so identical rows compare/hash equally across files.
    rows: list[tuple[tuple[tuple[str, str], ...], dict[str, str]]]
    # Per-file counters (total rows read, rows dropped for existing IDs, ...).
    stats: ProcessingStats

42 

43 

def create_redis_connection(host: str, port: int, db: int = 10) -> redis.Redis:
    """Open a Redis client that returns decoded ``str`` responses."""
    client = redis.Redis(
        host=host,
        port=port,
        db=db,
        decode_responses=True,
    )
    return client

46 

47 

def check_ids_existence_batch(
    rows: list[dict[str, str]], redis_client: redis.Redis
) -> list[bool]:
    """Report, per row, whether EVERY identifier in its "id" column exists.

    All EXISTS commands are sent through a single pipeline so only one
    network round trip is paid for the whole batch. A row with an empty
    "id" column yields ``False``.
    """
    # Identifiers are whitespace-separated inside the "id" column.
    ids_per_row = [row["id"].split() if row["id"] else [] for row in rows]

    # Queue one EXISTS per identifier, flattened across all rows.
    pipeline = redis_client.pipeline()
    for identifiers in ids_per_row:
        for identifier in identifiers:
            pipeline.exists(identifier)
    flat_results = pipeline.execute()

    # Walk the flat result list back into one verdict per row.
    verdicts: list[bool] = []
    cursor = 0
    for identifiers in ids_per_row:
        if not identifiers:
            verdicts.append(False)
            continue
        window = flat_results[cursor:cursor + len(identifiers)]
        cursor += len(identifiers)
        verdicts.append(all(window))

    return verdicts

77 

78 

def check_ids_sparql(
    identifiers: set[str],
    endpoint_url: str,
    workers: int = QLEVER_MAX_WORKERS,
    progress_callback: Callable[[int], None] | None = None,
) -> set[str]:
    """Return the subset of *identifiers* that exist at the SPARQL endpoint.

    Identifiers are "scheme:value" strings. They are queried in batches of
    QLEVER_BATCH_SIZE via VALUES clauses, with the batches executed in
    parallel by ``run_queries_parallel``. ``progress_callback`` (if given)
    receives each completed batch's size.
    """
    if not identifiers:
        return set()

    sorted_ids = sorted(identifiers)
    queries: list[str] = []
    sizes: list[int] = []

    for start in range(0, len(sorted_ids), QLEVER_BATCH_SIZE):
        chunk = sorted_ids[start:start + QLEVER_BATCH_SIZE]
        entries = []
        for identifier in chunk:
            # Split only on the first ":" — the literal value may contain more.
            scheme, literal_value = identifier.split(":", 1)
            # Escape backslashes before quotes so the SPARQL literal stays valid.
            escaped = literal_value.replace('\\', '\\\\').replace('"', '\\"')
            entries.append(
                '("{}"^^xsd:string datacite:{})'.format(escaped, scheme)
            )

        queries.append(
            (
                "PREFIX datacite: <http://purl.org/spar/datacite/>\n"
                "PREFIX literal: <http://www.essepuntato.it/2010/06/literalreification/>\n"
                "PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>\n"
                "SELECT ?val ?scheme WHERE {{\n"
                " VALUES (?val ?scheme) {{ {} }}\n"
                " ?id literal:hasLiteralValue ?val ;\n"
                " datacite:usesIdentifierScheme ?scheme .\n"
                "}}"
            ).format(" ".join(entries))
        )
        sizes.append(len(chunk))

    bindings_per_batch = run_queries_parallel(
        endpoint_url, queries, sizes, workers, progress_callback
    )

    # Rebuild "scheme:value" keys from the returned bindings.
    existing: set[str] = set()
    for batch_bindings in bindings_per_batch:
        for binding in batch_bindings:
            literal_value = binding["val"]["value"]
            scheme_uri = binding["scheme"]["value"]
            if scheme_uri.startswith(DATACITE_PREFIX):
                scheme = scheme_uri[len(DATACITE_PREFIX):]
            else:
                scheme = scheme_uri
            existing.add("{}:{}".format(scheme, literal_value))

    return existing

128 

129 

def get_csv_files(directory: str) -> List[str]:
    """Return the paths of all ``.csv`` files directly inside *directory*.

    The listing is sorted so runs are deterministic: ``os.listdir`` order is
    platform-dependent, and downstream deduplication keeps the first
    occurrence of each row, so a stable file order makes output reproducible.

    Raises:
        ValueError: if *directory* is not an existing directory.
    """
    if not os.path.isdir(directory):
        raise ValueError(
            "The specified path '{}' is not a directory".format(directory)
        )

    return sorted(
        os.path.join(directory, f)
        for f in os.listdir(directory)
        # Skip subdirectories whose names merely end in ".csv".
        if f.endswith(".csv") and os.path.isfile(os.path.join(directory, f))
    )

141 

142 

def collect_rows_from_file(file_path: str) -> FileResult:
    """Read one CSV file, pairing every row with its hashable signature.

    The signature (the row's items sorted into a tuple) is what later
    cross-file deduplication compares. No filtering happens here.
    """
    rows = get_csv_data(file_path, clean_data=False)
    stats = ProcessingStats(total_rows=len(rows))
    keyed_rows = [(tuple(sorted(row.items())), row) for row in rows]
    return FileResult(file_path=file_path, rows=keyed_rows, stats=stats)

152 

153 

def filter_existing_ids_from_file(
    file_path: str, redis_host: str, redis_port: int, redis_db: int
) -> FileResult:
    """Read one CSV file and drop rows whose identifiers all exist in Redis.

    Opens its own Redis connection, so it is safe to run in a worker
    process. Rows that survive are paired with their sorted-items
    signature for later cross-file deduplication.
    """
    client = create_redis_connection(redis_host, redis_port, redis_db)
    rows = get_csv_data(file_path, clean_data=False)

    stats = ProcessingStats(total_rows=len(rows))
    existence_flags = check_ids_existence_batch(rows, client)

    surviving: list[tuple[tuple[tuple[str, str], ...], dict[str, str]]] = []
    for row, already_known in zip(rows, existence_flags):
        if already_known:
            stats.existing_ids_rows += 1
            continue
        surviving.append((tuple(sorted(row.items())), row))

    return FileResult(file_path=file_path, rows=surviving, stats=stats)

174 

175 

def filter_sparql_results(
    results: list[FileResult],
    found_ids: set[str],
) -> None:
    """Remove rows whose identifiers ALL appear in *found_ids*, in place.

    Each removed row increments its file's ``existing_ids_rows`` counter.
    Rows with an empty "id" column are always kept.
    """
    for file_result in results:
        kept: list[tuple[tuple[tuple[str, str], ...], dict[str, str]]] = []
        for row_key, row in file_result.rows:
            identifiers = row["id"].split() if row["id"] else []
            if identifiers and all(i in found_ids for i in identifiers):
                file_result.stats.existing_ids_rows += 1
            else:
                kept.append((row_key, row))
        file_result.rows = kept

191 

192 

def deduplicate_and_write(
    results: list[FileResult],
    output_path: str,
    rows_per_file: int | None = None,
) -> ProcessingStats:
    """Deduplicate rows across all files and write them to CSV output.

    The first occurrence of each row signature wins. With *rows_per_file*
    set, output is split into ``0.csv``, ``1.csv``, ... inside
    *output_path*; otherwise everything goes to a single file resolved by
    ``resolve_output_path``. Returns the aggregated counters.
    """
    aggregate = ProcessingStats()
    seen: set[tuple[tuple[str, str], ...]] = set()
    pending: list[dict[str, str]] = []
    chunk_index = 0

    with create_progress() as progress:
        task = progress.add_task("Deduplicating and writing", total=len(results))

        for file_result in results:
            # Fold the per-file counters into the run totals.
            aggregate.total_rows += file_result.stats.total_rows
            aggregate.existing_ids_rows += file_result.stats.existing_ids_rows

            for row_key, row in file_result.rows:
                if row_key in seen:
                    aggregate.duplicate_rows += 1
                    continue
                seen.add(row_key)
                aggregate.processed_rows += 1
                pending.append(row)

                # Flush a full chunk as soon as it reaches the split size.
                if rows_per_file and len(pending) >= rows_per_file:
                    write_csv(
                        os.path.join(output_path, "{}.csv".format(chunk_index)),
                        pending,
                    )
                    chunk_index += 1
                    pending = []

            progress.advance(task)

    # Flush whatever is left over (partial chunk, or the whole single file).
    if pending:
        if rows_per_file:
            target = os.path.join(output_path, "{}.csv".format(chunk_index))
        else:
            target = resolve_output_path(output_path)
        write_csv(target, pending)

    return aggregate

236 

237 

def print_processing_report(stats: ProcessingStats, num_files: int) -> None:
    """Render a summary table of the run to the console.

    Shows absolute counts always, and percentage breakdowns when at least
    one input row was seen (avoiding division by zero).
    """
    report = Table(title="Processing Report")
    report.add_column("Metric", style="cyan")
    report.add_column("Value", style="green")

    for label, value in (
        ("Total input files processed", num_files),
        ("Total input rows", stats.total_rows),
        ("Rows discarded (duplicates)", stats.duplicate_rows),
        ("Rows discarded (existing IDs)", stats.existing_ids_rows),
        ("Rows written to output", stats.processed_rows),
    ):
        report.add_row(label, str(value))

    if stats.total_rows > 0:
        # Blank spacer row before the percentage section.
        report.add_row("", "")
        for label, numerator in (
            ("Duplicate rows %", stats.duplicate_rows),
            ("Existing IDs %", stats.existing_ids_rows),
            ("Processed rows %", stats.processed_rows),
        ):
            percent = (numerator / stats.total_rows) * 100
            report.add_row(label, "{:.1f}%".format(percent))

    console.print(report)

260 

261 

def main():  # pragma: no cover
    """CLI entry point: deduplicate input CSVs and drop already-known rows.

    Identifier existence is checked against either Redis (``--redis-port``)
    or a SPARQL endpoint (``--sparql-endpoint``). Returns 0 on success,
    1 when the input directory contains no CSV files.
    """
    parser = argparse.ArgumentParser(
        description="Process CSV files and check IDs against Redis or SPARQL endpoint",
        formatter_class=RichHelpFormatter,
    )
    parser.add_argument("input_dir", help="Directory containing input CSV files")
    parser.add_argument(
        "output",
        help="Output path: directory for split files, or path ending in .csv for single file",
    )

    # The two output modes are mutually exclusive.
    output_group = parser.add_mutually_exclusive_group()
    output_group.add_argument(
        "--rows-per-file",
        type=int,
        default=None,
        help="Split output into files of N rows each (default: 3000)",
    )
    output_group.add_argument(
        "--single-file",
        action="store_true",
        help="Write all output rows to a single CSV file",
    )

    parser.add_argument(
        "--redis-host", default="localhost", help="Redis host (default: localhost)"
    )
    parser.add_argument(
        "--redis-port", type=int, help="Redis port (required for Redis mode)"
    )
    parser.add_argument(
        "--redis-db",
        type=int,
        default=10,
        help="Redis database number (default: 10)",
    )
    parser.add_argument(
        "--sparql-endpoint",
        help="SPARQL endpoint URL for ID existence checking (alternative to Redis)",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of parallel workers (default: 4)",
    )
    args = parser.parse_args()

    # At least one existence backend must be configured; SPARQL takes
    # precedence when both are given (see use_sparql below).
    if not args.sparql_endpoint and args.redis_port is None:
        parser.error("either --redis-port or --sparql-endpoint is required")

    # rows_per_file = None means "single output file"; otherwise split every
    # N rows, defaulting to 3000 when neither flag is passed.
    if args.single_file:
        rows_per_file = None
    elif args.rows_per_file is not None:
        rows_per_file = args.rows_per_file
    else:
        rows_per_file = 3000

    if rows_per_file:
        # Split mode writes numbered files into the output directory.
        os.makedirs(args.output, exist_ok=True)

    csv_files = get_csv_files(args.input_dir)
    if not csv_files:
        console.print(
            "[red]No CSV files found in directory: {}[/red]".format(args.input_dir)
        )
        return 1

    use_sparql = args.sparql_endpoint is not None

    console.print(
        "Found [green]{}[/green] CSV files to process with [green]{}[/green] workers ({})".format(
            len(csv_files), args.workers, "SPARQL" if use_sparql else "Redis"
        )
    )

    # Remember the original file order so results can be re-sorted after the
    # as_completed() fan-in, keeping deduplication deterministic per run.
    file_order = {f: i for i, f in enumerate(csv_files)}

    if use_sparql:
        # SPARQL mode: read all files first, then check every unique
        # identifier against the endpoint in one parallel pass.
        results: list[FileResult] = []
        with create_progress() as progress:
            task = progress.add_task("Reading CSV files", total=len(csv_files))
            # forkserver spawns workers from a clean server process instead
            # of forking this one's full state.
            with ProcessPoolExecutor(
                max_workers=args.workers,
                mp_context=multiprocessing.get_context('forkserver')
            ) as executor:
                futures = {
                    executor.submit(collect_rows_from_file, f): f
                    for f in csv_files
                }
                for future in as_completed(futures):
                    results.append(future.result())
                    progress.advance(task)

        results.sort(key=lambda r: file_order[r.file_path])

        # Gather the unique identifiers across all rows for batched lookup.
        all_ids: set[str] = set()
        for result in results:
            for _hash, row in result.rows:
                ids_str = row["id"]
                if ids_str:
                    all_ids.update(ids_str.split())

        if all_ids:
            console.print(
                "Checking [green]{}[/green] unique identifiers against SPARQL endpoint".format(len(all_ids))
            )
            with create_progress() as progress:
                task = progress.add_task("Querying SPARQL", total=len(all_ids))

                def on_batch(batch_size: int) -> None:
                    # Advance the bar by each completed batch's size.
                    progress.advance(task, batch_size)

                found_ids = check_ids_sparql(
                    all_ids, args.sparql_endpoint, args.workers, on_batch
                )
        else:
            found_ids = set()

        filter_sparql_results(results, found_ids)
    else:
        # Redis mode: each worker reads one file and filters it against
        # Redis directly, using its own connection.
        results = []
        with create_progress() as progress:
            task = progress.add_task("Filtering existing IDs", total=len(csv_files))
            with ProcessPoolExecutor(
                max_workers=args.workers,
                mp_context=multiprocessing.get_context('forkserver')
            ) as executor:
                futures = {
                    executor.submit(
                        filter_existing_ids_from_file,
                        csv_file,
                        args.redis_host,
                        args.redis_port,
                        args.redis_db,
                    ): csv_file
                    for csv_file in csv_files
                }
                for future in as_completed(futures):
                    results.append(future.result())
                    progress.advance(task)

        results.sort(key=lambda r: file_order[r.file_path])

    total_stats = deduplicate_and_write(results, args.output, rows_per_file)

    print_processing_report(total_stats, len(csv_files))

    return 0

411 

412 

if __name__ == "__main__":  # pragma: no cover
    # Propagate main()'s return value (0 success, 1 no input files) as the
    # process exit status; previously the value was discarded and the script
    # always exited 0, hiding failures from shell scripts and CI.
    raise SystemExit(main())