Coverage for oc_meta / run / migration / extract_subset.py: 100%
83 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-21 09:24 +0000
1#!/usr/bin/env python
3# SPDX-FileCopyrightText: 2026 Arcangelo Massari <arcangelo.massari@unibo.it>
4#
5# SPDX-License-Identifier: ISC
7# -*- coding: utf-8 -*-
9import argparse
10import gzip
11import sys
12from urllib.parse import urlparse
14import rdflib
15from rdflib.term import Node
16from rich_argparse import RichHelpFormatter
17from sparqlite import SPARQLClient
19CHUNK_SIZE = 20
def get_subjects_of_class(client: SPARQLClient, class_uri: str, limit: int) -> list[str]:
    """Return up to ``limit`` URIs of subjects typed as ``class_uri``.

    :param client: open SPARQL client used to run the query
    :param class_uri: URI of the class whose instances are wanted
    :param limit: maximum number of subject URIs to fetch
    :return: list of subject URI strings
    """
    sparql = f"""
    SELECT ?s
    WHERE {{
        ?s a <{class_uri}> .
    }}
    LIMIT {limit}
    """
    response = client.query(sparql)
    bindings = response["results"]["bindings"]
    return [binding["s"]["value"] for binding in bindings]
def load_entities_from_file(entities_file: str) -> list[str]:
    """Read entity URIs from a plain-text file, one per line.

    Blank lines and surrounding whitespace are discarded.

    :param entities_file: path to a text file with one URI per line
    :return: list of non-empty, stripped lines in file order
    """
    # Explicit encoding avoids depending on the platform's default locale,
    # which previously made results machine-dependent for non-ASCII URIs.
    with open(entities_file, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]
def parse_object(result: dict[str, dict[str, str]]) -> rdflib.URIRef | rdflib.BNode | rdflib.Literal:
    """Convert the ``o`` binding of a SPARQL JSON result row into an rdflib term.

    :param result: one row from ``results.bindings`` of a SPARQL JSON response
    :return: a ``URIRef``, ``BNode``, or ``Literal`` matching the binding
    """
    obj = result["o"]
    value = obj["value"]
    kind = obj["type"]
    if kind == 'uri':
        return rdflib.URIRef(value)
    if kind == 'bnode':
        return rdflib.BNode(value)
    # Literal: a declared datatype takes precedence over a language tag.
    datatype = obj.get('datatype')
    if datatype is not None:
        return rdflib.Literal(value, datatype=datatype)
    lang = obj.get('xml:lang')
    if lang is not None:
        return rdflib.Literal(value, lang=lang)
    return rdflib.Literal(value)
def get_triples_for_entities(
    client: SPARQLClient,
    entity_uris: list[str],
    use_graphs: bool,
) -> list[tuple[rdflib.URIRef, rdflib.URIRef, Node, rdflib.URIRef | None]]:
    """Fetch every triple whose subject is one of ``entity_uris``.

    Subjects are queried in chunks of ``CHUNK_SIZE`` to keep each VALUES
    clause small enough for the endpoint.

    :param client: open SPARQL client used to run the queries
    :param entity_uris: subject URIs whose outgoing triples are wanted
    :param use_graphs: when True, query inside ``GRAPH ?g`` and report the
        named graph; otherwise the graph component is ``None``
    :return: list of ``(subject, predicate, object, graph-or-None)`` tuples
    """
    collected: list[tuple[rdflib.URIRef, rdflib.URIRef, Node, rdflib.URIRef | None]] = []

    for start in range(0, len(entity_uris), CHUNK_SIZE):
        chunk = entity_uris[start:start + CHUNK_SIZE]
        values = " ".join(f"<{uri}>" for uri in chunk)

        if use_graphs:
            query = f"""
            SELECT ?s ?p ?o ?g
            WHERE {{
                GRAPH ?g {{
                    VALUES ?s {{ {values} }}
                    ?s ?p ?o .
                }}
            }}
            """
        else:
            query = f"""
            SELECT ?s ?p ?o
            WHERE {{
                VALUES ?s {{ {values} }}
                ?s ?p ?o .
            }}
            """

        response = client.query(query)
        for binding in response["results"]["bindings"]:
            subject = rdflib.URIRef(binding["s"]["value"])
            predicate = rdflib.URIRef(binding["p"]["value"])
            obj = parse_object(binding)
            graph_name = rdflib.URIRef(binding["g"]["value"]) if "g" in binding else None
            collected.append((subject, predicate, obj, graph_name))

    return collected
def extract_subset(
    endpoint: str,
    limit: int,
    output_file: str,
    compress: bool,
    max_retries: int = 5,
    class_uri: str | None = None,
    entities_file: str | None = None,
    use_graphs: bool = True,
) -> tuple[int, str]:
    """Crawl a SPARQL endpoint from seed entities and serialize the subset.

    Seeds come either from ``entities_file`` or from the instances of
    ``class_uri``. The crawl repeatedly fetches all triples whose subject is
    pending and enqueues every URI object not yet processed, until no new
    entities remain, then serializes the accumulated store.

    :param endpoint: SPARQL endpoint URL
    :param limit: maximum number of seed entities when discovering by class
    :param output_file: destination path; ``.gz`` is appended when compressing
    :param compress: gzip-compress the serialized output
    :param max_retries: retry budget for failed SPARQL queries
    :param class_uri: class whose instances seed the crawl (ignored when
        ``entities_file`` is given)
    :param entities_file: file with one seed entity URI per line
    :param use_graphs: query named graphs and emit N-Quads instead of N-Triples
    :return: ``(number of processed entities, final output file name)``
    :raises ValueError: if neither ``class_uri`` nor ``entities_file`` is given
    """
    with SPARQLClient(endpoint, max_retries=max_retries, backoff_factor=2, timeout=3600) as client:
        if entities_file:
            subjects = load_entities_from_file(entities_file)
        else:
            # Was an assert, which is stripped under ``python -O``;
            # validate the input explicitly instead.
            if class_uri is None:
                raise ValueError(
                    "Either class_uri or entities_file must be provided"
                )
            subjects = get_subjects_of_class(client, class_uri, limit)

        processed_entities: set[str] = set()
        pending_entities = set(subjects)

        # Exactly one of these is populated, depending on the output format.
        dataset: rdflib.Dataset | None = None
        graph: rdflib.Graph | None = None
        if use_graphs:
            dataset = rdflib.Dataset()
        else:
            graph = rdflib.Graph()

        while pending_entities:
            # Sort for deterministic query order across runs.
            batch = sorted(pending_entities - processed_entities)
            if not batch:
                break  # pragma: no cover

            processed_entities.update(batch)
            pending_entities.clear()

            quads = get_triples_for_entities(client, batch, use_graphs)

            for s_term, p_term, o_term, g_term in quads:
                if dataset is not None:
                    dataset.graph(g_term).add((s_term, p_term, o_term))
                elif graph is not None:
                    graph.add((s_term, p_term, o_term))

                # Follow URI objects so connected entities join the subset.
                if isinstance(o_term, rdflib.URIRef):
                    o_str = str(o_term)
                    if o_str not in processed_entities:
                        pending_entities.add(o_str)

        store = dataset if dataset is not None else graph
        assert store is not None  # internal invariant: one store was created
        output_format = "nquads" if use_graphs else "nt"
        if compress:
            if not output_file.endswith('.gz'):
                output_file = output_file + '.gz'
            with gzip.open(output_file, 'wb') as f:
                store.serialize(destination=f, format=output_format)  # type: ignore[arg-type]
        else:
            store.serialize(destination=output_file, format=output_format)

        return len(processed_entities), output_file
def main():  # pragma: no cover
    """CLI entry point: parse arguments, validate the endpoint, run the extraction.

    :return: process exit code (0 on success, 1 on error)
    """
    arg_parser = argparse.ArgumentParser(
        description='Extract a subset of data from a SPARQL endpoint',
        formatter_class=RichHelpFormatter,
    )
    arg_parser.add_argument('--endpoint', default='http://localhost:8890/sparql',
                            help='SPARQL endpoint URL (default: http://localhost:8890/sparql)')

    # The two seed-discovery modes are mutually exclusive.
    source_group = arg_parser.add_mutually_exclusive_group()
    source_group.add_argument('--class', dest='class_uri',
                              help='Class URI to extract instances of (default: fabio:Expression)')
    source_group.add_argument('--entities-file', dest='entities_file',
                              help='File with entity URIs to extract (one per line)')

    arg_parser.add_argument('--limit', type=int, default=1000,
                            help='Maximum number of initial entities to process (default: 1000)')
    arg_parser.add_argument('--output', default='output.nq',
                            help='Output file name (default: output.nq)')
    arg_parser.add_argument('--compress', action='store_true',
                            help='Compress output file using gzip')
    arg_parser.add_argument('--retries', type=int, default=5,
                            help='Maximum number of retries for failed queries (default: 5)')
    arg_parser.add_argument('--no-graphs', action='store_true',
                            help='Disable named graph queries and output N-Triples instead of N-Quads')

    args = arg_parser.parse_args()

    # Default discovery mode: instances of fabio:Expression.
    if not args.class_uri and not args.entities_file:
        args.class_uri = 'http://purl.org/spar/fabio/Expression'

    try:
        parsed_url = urlparse(args.endpoint)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError("Invalid endpoint URL")
    except Exception:
        print(f"Error: Invalid endpoint URL: {args.endpoint}")
        return 1

    try:
        entity_count, final_output_file = extract_subset(
            args.endpoint,
            args.limit,
            args.output,
            args.compress,
            args.retries,
            class_uri=args.class_uri,
            entities_file=args.entities_file,
            use_graphs=not args.no_graphs,
        )
        print(f"Extraction complete. Processed {entity_count} entities.")
        print(f"Output saved to {final_output_file}")
        return 0
    except Exception as e:
        print(f"Error: {e}")
        return 1
# Propagate main()'s return code as the process exit status so shell
# callers can detect failures.
if __name__ == "__main__":  # pragma: no cover
    sys.exit(main())