Coverage for oc_meta / run / meta / preprocess_input.py: 100%
161 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-21 09:24 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-21 09:24 +0000
1#!/usr/bin/python
3# SPDX-FileCopyrightText: 2025-2026 Arcangelo Massari <arcangelo.massari@unibo.it>
4#
5# SPDX-License-Identifier: ISC
7from __future__ import annotations
9import argparse
10import multiprocessing
11import os
12from concurrent.futures import ProcessPoolExecutor, as_completed
13from dataclasses import dataclass
14from typing import Callable, List
16import redis
17from rich.table import Table
18from rich_argparse import RichHelpFormatter
20from oc_meta.constants import QLEVER_BATCH_SIZE, QLEVER_MAX_WORKERS
21from oc_meta.lib.console import console, create_progress
22from oc_meta.lib.file_manager import get_csv_data, write_csv
23from oc_meta.lib.sparql import run_queries_parallel
24from oc_meta.run.meta.merge_csv import resolve_output_path
26DATACITE_PREFIX = "http://purl.org/spar/datacite/"
@dataclass
class ProcessingStats:
    """Row counters accumulated while preprocessing one file (or the whole run)."""

    # Rows read from the input CSV(s).
    total_rows: int = 0
    # Rows dropped because an identical row was already emitted.
    duplicate_rows: int = 0
    # Rows dropped because every identifier on the row already exists
    # in Redis / the SPARQL endpoint.
    existing_ids_rows: int = 0
    # Rows kept and written to the output.
    processed_rows: int = 0
@dataclass
class FileResult:
    """Outcome of processing a single input CSV file."""

    # Absolute or relative path of the processed CSV file.
    file_path: str
    # Surviving rows as (hashable row fingerprint, original row dict) pairs;
    # the fingerprint is tuple(sorted(row.items())) and is used for dedup.
    rows: list[tuple[tuple[tuple[str, str], ...], dict[str, str]]]
    # Per-file counters (total read, discarded because IDs already exist, ...).
    stats: ProcessingStats
def create_redis_connection(host: str, port: int, db: int = 10) -> redis.Redis:
    """Build a Redis client that returns decoded (str) values instead of bytes."""
    connection_kwargs = {
        "host": host,
        "port": port,
        "db": db,
        "decode_responses": True,
    }
    return redis.Redis(**connection_kwargs)
def check_ids_existence_batch(
    rows: list[dict[str, str]], redis_client: redis.Redis
) -> list[bool]:
    """Return, for each row, whether ALL of its space-separated IDs exist in Redis.

    A row with an empty ``id`` field yields ``False``. All EXISTS commands are
    sent through a single pipeline to avoid one round-trip per identifier.
    """
    per_row_ids = [
        (row["id"].split() if row["id"] else []) for row in rows
    ]

    # Queue every EXISTS check in one pipeline round-trip.
    pipe = redis_client.pipeline()
    for id_group in per_row_ids:
        for identifier in id_group:
            pipe.exists(identifier)
    flags = pipe.execute()

    # Walk the flat reply list back into per-row verdicts by position.
    verdicts: list[bool] = []
    cursor = 0
    for id_group in per_row_ids:
        group_flags = flags[cursor:cursor + len(id_group)]
        cursor += len(id_group)
        # Empty groups are False; otherwise every identifier must exist.
        verdicts.append(bool(id_group) and all(group_flags))

    return verdicts
def check_ids_sparql(
    identifiers: set[str],
    endpoint_url: str,
    workers: int = QLEVER_MAX_WORKERS,
    progress_callback: Callable[[int], None] | None = None,
) -> set[str]:
    """Return the subset of ``identifiers`` already present at the endpoint.

    Identifiers are ``scheme:value`` strings; they are checked in sorted
    batches of QLEVER_BATCH_SIZE via parallel VALUES queries.
    """
    if not identifiers:
        return set()

    ordered_ids = sorted(identifiers)
    queries: list[str] = []
    sizes: list[int] = []

    for start in range(0, len(ordered_ids), QLEVER_BATCH_SIZE):
        chunk = ordered_ids[start:start + QLEVER_BATCH_SIZE]
        entries = []
        for identifier in chunk:
            scheme_part, literal_part = identifier.split(":", 1)
            # Escape backslashes first, then quotes, for the SPARQL literal.
            safe_literal = literal_part.replace('\\', '\\\\').replace('"', '\\"')
            entries.append(
                '("{}"^^xsd:string datacite:{})'.format(safe_literal, scheme_part)
            )

        query = (
            "PREFIX datacite: <http://purl.org/spar/datacite/>\n"
            "PREFIX literal: <http://www.essepuntato.it/2010/06/literalreification/>\n"
            "PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>\n"
            "SELECT ?val ?scheme WHERE {\n"
            " VALUES (?val ?scheme) { " + " ".join(entries) + " }\n"
            " ?id literal:hasLiteralValue ?val ;\n"
            " datacite:usesIdentifierScheme ?scheme .\n"
            "}"
        )
        queries.append(query)
        sizes.append(len(chunk))

    bindings_per_batch = run_queries_parallel(
        endpoint_url, queries, sizes, workers, progress_callback
    )

    # Rebuild scheme:value keys from the bindings that came back.
    matched: set[str] = set()
    for batch_bindings in bindings_per_batch:
        for binding in batch_bindings:
            literal_value = binding["val"]["value"]
            scheme_uri = binding["scheme"]["value"]
            if scheme_uri.startswith(DATACITE_PREFIX):
                scheme_name = scheme_uri[len(DATACITE_PREFIX):]
            else:
                scheme_name = scheme_uri
            matched.add("{}:{}".format(scheme_name, literal_value))

    return matched
def get_csv_files(directory: str) -> List[str]:
    """Return the paths of all ``.csv`` regular files directly inside *directory*.

    The result is sorted by filename so that chunked output (and the
    file-order mapping built by ``main``) is deterministic across runs —
    ``os.listdir`` returns entries in arbitrary, platform-dependent order.

    Raises:
        ValueError: if *directory* is not an existing directory.
    """
    if not os.path.isdir(directory):
        raise ValueError(
            "The specified path '{}' is not a directory".format(directory)
        )

    return sorted(
        os.path.join(directory, f)
        for f in os.listdir(directory)
        # os.path.isfile also filters out directories that happen to end in .csv
        if f.endswith(".csv") and os.path.isfile(os.path.join(directory, f))
    )
def collect_rows_from_file(file_path: str) -> FileResult:
    """Read one CSV file and pair every row with its dedup fingerprint.

    The fingerprint is the sorted (key, value) tuple of the row, which makes
    identical rows hashable and comparable across files.
    """
    parsed = get_csv_data(file_path, clean_data=False)
    stats = ProcessingStats()
    stats.total_rows = len(parsed)
    fingerprinted = [
        (tuple(sorted(row.items())), row) for row in parsed
    ]
    return FileResult(file_path=file_path, rows=fingerprinted, stats=stats)
def filter_existing_ids_from_file(
    file_path: str, redis_host: str, redis_port: int, redis_db: int
) -> FileResult:
    """Read one CSV file and drop rows whose IDs all already exist in Redis.

    Runs in a worker process, so it opens its own Redis connection from the
    passed parameters instead of sharing a client across processes.
    """
    client = create_redis_connection(redis_host, redis_port, redis_db)
    parsed = get_csv_data(file_path, clean_data=False)

    stats = ProcessingStats()
    stats.total_rows = len(parsed)

    survivors: list[tuple[tuple[tuple[str, str], ...], dict[str, str]]] = []
    for row, already_known in zip(parsed, check_ids_existence_batch(parsed, client)):
        if already_known:
            # Every identifier on the row is present in Redis: discard it.
            stats.existing_ids_rows += 1
            continue
        survivors.append((tuple(sorted(row.items())), row))

    return FileResult(file_path=file_path, rows=survivors, stats=stats)
def filter_sparql_results(
    results: list[FileResult],
    found_ids: set[str],
) -> None:
    """Drop, in place, every row whose identifiers were ALL found at the endpoint.

    Rows with an empty ``id`` field are kept (mirroring the Redis path).
    Each dropped row increments the owning result's ``existing_ids_rows``.
    """
    for file_result in results:
        kept: list[tuple[tuple[tuple[str, str], ...], dict[str, str]]] = []
        for fingerprint, row in file_result.rows:
            ids_value = row["id"]
            if ids_value:
                identifiers = ids_value.split()
                if identifiers and all(one_id in found_ids for one_id in identifiers):
                    file_result.stats.existing_ids_rows += 1
                    continue
            kept.append((fingerprint, row))
        file_result.rows = kept
def deduplicate_and_write(
    results: list[FileResult],
    output_path: str,
    rows_per_file: int | None = None,
) -> ProcessingStats:
    """Write surviving rows, skipping duplicates, and return aggregate stats.

    When ``rows_per_file`` is set, output is chunked into ``0.csv``,
    ``1.csv``, ... inside ``output_path``; otherwise everything goes to the
    single file resolved by ``resolve_output_path``.
    """
    emitted_fingerprints: set[tuple[tuple[str, str], ...]] = set()
    pending: list[dict[str, str]] = []
    chunk_index = 0

    totals = ProcessingStats()

    with create_progress() as progress:
        task = progress.add_task("Deduplicating and writing", total=len(results))

        for file_result in results:
            totals.total_rows += file_result.stats.total_rows
            totals.existing_ids_rows += file_result.stats.existing_ids_rows

            for fingerprint, row in file_result.rows:
                if fingerprint in emitted_fingerprints:
                    totals.duplicate_rows += 1
                    continue

                emitted_fingerprints.add(fingerprint)
                totals.processed_rows += 1
                pending.append(row)

                # Flush a full chunk as soon as it reaches the limit.
                if rows_per_file and len(pending) >= rows_per_file:
                    chunk_file = os.path.join(output_path, "{}.csv".format(chunk_index))
                    write_csv(chunk_file, pending)
                    chunk_index += 1
                    pending = []

            progress.advance(task)

    # Flush whatever is left (partial chunk, or the whole single-file output).
    if pending:
        if rows_per_file:
            target = os.path.join(output_path, "{}.csv".format(chunk_index))
        else:
            target = resolve_output_path(output_path)
        write_csv(target, pending)

    return totals
def print_processing_report(stats: ProcessingStats, num_files: int) -> None:
    """Render the final run summary as a two-column rich table on the console."""
    report = Table(title="Processing Report")
    report.add_column("Metric", style="cyan")
    report.add_column("Value", style="green")

    base_rows = [
        ("Total input files processed", str(num_files)),
        ("Total input rows", str(stats.total_rows)),
        ("Rows discarded (duplicates)", str(stats.duplicate_rows)),
        ("Rows discarded (existing IDs)", str(stats.existing_ids_rows)),
        ("Rows written to output", str(stats.processed_rows)),
    ]
    for metric, value in base_rows:
        report.add_row(metric, value)

    # Percentages only make sense when there was at least one input row.
    if stats.total_rows > 0:
        total = stats.total_rows
        report.add_row("", "")
        report.add_row(
            "Duplicate rows %", "{:.1f}%".format(stats.duplicate_rows / total * 100)
        )
        report.add_row(
            "Existing IDs %", "{:.1f}%".format(stats.existing_ids_rows / total * 100)
        )
        report.add_row(
            "Processed rows %", "{:.1f}%".format(stats.processed_rows / total * 100)
        )

    console.print(report)
def main():  # pragma: no cover
    """CLI entry point: dedupe input CSVs and drop rows whose IDs already exist.

    Existence is checked either against Redis (--redis-port) or a SPARQL
    endpoint (--sparql-endpoint). Returns 0 on success, 1 if the input
    directory contains no CSV files.
    """
    parser = argparse.ArgumentParser(
        description="Process CSV files and check IDs against Redis or SPARQL endpoint",
        formatter_class=RichHelpFormatter,
    )
    parser.add_argument("input_dir", help="Directory containing input CSV files")
    parser.add_argument(
        "output",
        help="Output path: directory for split files, or path ending in .csv for single file",
    )

    # --rows-per-file and --single-file are mutually exclusive output modes.
    output_group = parser.add_mutually_exclusive_group()
    output_group.add_argument(
        "--rows-per-file",
        type=int,
        default=None,
        help="Split output into files of N rows each (default: 3000)",
    )
    output_group.add_argument(
        "--single-file",
        action="store_true",
        help="Write all output rows to a single CSV file",
    )

    parser.add_argument(
        "--redis-host", default="localhost", help="Redis host (default: localhost)"
    )
    parser.add_argument(
        "--redis-port", type=int, help="Redis port (required for Redis mode)"
    )
    parser.add_argument(
        "--redis-db",
        type=int,
        default=10,
        help="Redis database number (default: 10)",
    )
    parser.add_argument(
        "--sparql-endpoint",
        help="SPARQL endpoint URL for ID existence checking (alternative to Redis)",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of parallel workers (default: 4)",
    )
    args = parser.parse_args()

    # Exactly one existence-checking backend must be configured.
    if not args.sparql_endpoint and args.redis_port is None:
        parser.error("either --redis-port or --sparql-endpoint is required")

    # Resolve the output mode: None means "single file", otherwise chunk size.
    if args.single_file:
        rows_per_file = None
    elif args.rows_per_file is not None:
        rows_per_file = args.rows_per_file
    else:
        rows_per_file = 3000

    # Chunked mode writes numbered files into a directory; make sure it exists.
    if rows_per_file:
        os.makedirs(args.output, exist_ok=True)

    csv_files = get_csv_files(args.input_dir)
    if not csv_files:
        console.print(
            "[red]No CSV files found in directory: {}[/red]".format(args.input_dir)
        )
        return 1

    use_sparql = args.sparql_endpoint is not None

    console.print(
        "Found [green]{}[/green] CSV files to process with [green]{}[/green] workers ({})".format(
            len(csv_files), args.workers, "SPARQL" if use_sparql else "Redis"
        )
    )

    # Remember the original file order so results can be re-sorted after
    # as_completed() returns them in completion order.
    file_order = {f: i for i, f in enumerate(csv_files)}

    if use_sparql:
        # SPARQL mode: read all files first, then check all unique IDs in
        # one parallel-query pass, then filter the in-memory rows.
        results: list[FileResult] = []
        with create_progress() as progress:
            task = progress.add_task("Reading CSV files", total=len(csv_files))
            with ProcessPoolExecutor(
                max_workers=args.workers,
                mp_context=multiprocessing.get_context('forkserver')
            ) as executor:
                futures = {
                    executor.submit(collect_rows_from_file, f): f
                    for f in csv_files
                }
                for future in as_completed(futures):
                    results.append(future.result())
                    progress.advance(task)

        results.sort(key=lambda r: file_order[r.file_path])

        # Collect every space-separated identifier across all surviving rows.
        all_ids: set[str] = set()
        for result in results:
            for _hash, row in result.rows:
                ids_str = row["id"]
                if ids_str:
                    all_ids.update(ids_str.split())

        if all_ids:
            console.print(
                "Checking [green]{}[/green] unique identifiers against SPARQL endpoint".format(len(all_ids))
            )
            with create_progress() as progress:
                task = progress.add_task("Querying SPARQL", total=len(all_ids))

                # Advance the progress bar as each query batch completes.
                def on_batch(batch_size: int) -> None:
                    progress.advance(task, batch_size)

                found_ids = check_ids_sparql(
                    all_ids, args.sparql_endpoint, args.workers, on_batch
                )
        else:
            found_ids = set()

        filter_sparql_results(results, found_ids)
    else:
        # Redis mode: each worker process reads a file and filters it
        # against Redis on its own connection.
        results = []
        with create_progress() as progress:
            task = progress.add_task("Filtering existing IDs", total=len(csv_files))
            with ProcessPoolExecutor(
                max_workers=args.workers,
                mp_context=multiprocessing.get_context('forkserver')
            ) as executor:
                futures = {
                    executor.submit(
                        filter_existing_ids_from_file,
                        csv_file,
                        args.redis_host,
                        args.redis_port,
                        args.redis_db,
                    ): csv_file
                    for csv_file in csv_files
                }
                for future in as_completed(futures):
                    results.append(future.result())
                    progress.advance(task)

        results.sort(key=lambda r: file_order[r.file_path])

    total_stats = deduplicate_and_write(results, args.output, rows_per_file)

    print_processing_report(total_stats, len(csv_files))

    return 0
# Script entry point: run the CLI when executed directly.
if __name__ == "__main__":  # pragma: no cover
    main()