Coverage for heritrace / extensions.py: 99%
271 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-07-02 10:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-07-02 10:16 +0000
1# SPDX-FileCopyrightText: 2024-2026 Arcangelo Massari <arcangelo.massari@unibo.it>
2#
3# SPDX-License-Identifier: ISC
5import json
6import os
7from collections import defaultdict
8from dataclasses import dataclass
9from datetime import datetime, timedelta, timezone
10from pathlib import Path
11from typing import cast
12from urllib.parse import urlparse, urlunparse
14import yaml
15from flask import Flask, current_app, g, redirect, session, url_for
16from flask_babel import Babel
17from flask_login import LoginManager
18from flask_login.signals import user_loaded_from_cookie
19from rdflib import Graph
20from rdflib_ocdm.counter_handler.counter_handler import CounterHandler
21from redis import Redis
22from redis.exceptions import RedisError
23from SPARQLWrapper import JSON
24from time_agnostic_library.support import generate_config_file
26from heritrace.models import User
27from heritrace.services.resource_lock_manager import ResourceLockManager
28from heritrace.sparql import SPARQLWrapperWithRetry, get_sparql_bindings, select_results
29from heritrace.uri_generator.uri_generator import CounterBasedURIGenerator
30from heritrace.utils.filters import Filter, split_namespace
@dataclass(frozen=True)
class AppState:
    """Per-application bundle of services and configuration.

    ``init_extensions`` stores an instance in ``app.extensions["heritrace"]``,
    and ``get_app_state()`` reads it back. ``frozen=True`` only prevents
    rebinding the fields: the dicts, lists, sets and graphs they hold remain
    mutable objects.
    """

    # URL of the dataset SPARQL endpoint (after ``adjust_endpoint_url``).
    dataset_endpoint: str
    # URL of the provenance SPARQL endpoint (after ``adjust_endpoint_url``).
    provenance_endpoint: str
    # Client bound to ``dataset_endpoint``.
    sparql: SPARQLWrapperWithRetry
    # Client bound to ``provenance_endpoint``.
    provenance_sparql: SPARQLWrapperWithRetry
    # Parsed change-tracking configuration JSON.
    change_tracking_config: dict
    # Jinja filter helper; ``init_extensions`` temporarily stores a ``None``
    # placeholder here while the global variables are being loaded.
    custom_filter: Filter
    # The ``rules`` list read from the display-rules YAML file (may be empty).
    display_rules: list[dict]
    # Form-field definitions derived from the SHACL graph (empty without SHACL).
    form_fields_cache: dict
    # Value of the ``DATASET_IS_QUADSTORE`` setting.
    dataset_is_quadstore: bool
    # Parsed SHACL shapes graph (an empty Graph when no SHACL file is loaded).
    shacl_graph: Graph
    # Class URIs targeted by more than one visible SHACL shape.
    classes_with_multiple_shapes: set[str]
def get_app_state() -> AppState:
    """Return the ``AppState`` registered on the active Flask application.

    Requires an application context. Raises ``KeyError`` if no state has been
    stored under ``app.extensions["heritrace"]`` yet.
    """
    registered_extensions = current_app.extensions
    return registered_extensions["heritrace"]
def init_extensions(
    app: Flask, babel: Babel, login_manager: LoginManager, redis: Redis
) -> None:
    """Wire Babel, Flask-Login, SPARQL clients and shared state into *app*.

    The shared ``AppState`` is published twice:

    1. A provisional state holds the final endpoints, clients and change-tracking
       config, but has an empty display-rule list, an empty SHACL graph, no form
       fields and a ``None`` custom filter. This is presumably so that code run by
       ``initialize_global_variables`` can already call ``get_app_state()``.
    2. Once the global variables and Jinja filters are loaded, the provisional
       state is replaced by the complete one.

    Afterwards the login manager and the Redis client are exposed under
    ``app.extensions["login_manager"]`` and ``app.extensions["redis_client"]``.
    """
    translations_dir = (
        Path(__file__).resolve().parent.parent / "babel" / "translations"
    )
    babel.init_app(
        app=app,
        locale_selector=lambda: session.get("lang", "en"),
        default_translation_directories=str(translations_dir),
    )

    init_login_manager(app, login_manager)

    sparql_services = init_sparql_services(app)
    (
        dataset_endpoint,
        provenance_endpoint,
        sparql,
        provenance_sparql,
        change_tracking_config,
    ) = sparql_services
    initialize_counter_handler(app, redis, sparql, provenance_sparql)

    def publish_state(
        *,
        custom_filter: Filter,
        display_rules: list[dict],
        form_fields_cache: dict,
        dataset_is_quadstore: bool,
        shacl_graph: Graph,
        classes_with_multiple_shapes: set[str],
    ) -> None:
        # Endpoint, client and change-tracking fields are the same in both
        # phases, so they are taken from the enclosing scope.
        app.extensions["heritrace"] = AppState(
            dataset_endpoint=dataset_endpoint,
            provenance_endpoint=provenance_endpoint,
            sparql=sparql,
            provenance_sparql=provenance_sparql,
            change_tracking_config=change_tracking_config,
            custom_filter=custom_filter,
            display_rules=display_rules,
            form_fields_cache=form_fields_cache,
            dataset_is_quadstore=dataset_is_quadstore,
            shacl_graph=shacl_graph,
            classes_with_multiple_shapes=classes_with_multiple_shapes,
        )

    # Phase 1: provisional state.
    # NOTE(review): while this placeholder is active, get_display_rules()
    # returns []. If is_entity_type_visible() (called during
    # initialize_global_variables) reads the rules from app state, it sees the
    # empty list. Confirm against heritrace.utils.display_rules_utils.
    publish_state(
        custom_filter=cast("Filter", None),
        display_rules=[],
        form_fields_cache={},
        dataset_is_quadstore=False,
        shacl_graph=Graph(),
        classes_with_multiple_shapes=set(),
    )

    (
        loaded_rules,
        loaded_form_fields,
        loaded_is_quadstore,
        loaded_shacl_graph,
        loaded_multi_shape_classes,
    ) = initialize_global_variables(app)
    custom_filter = init_filters(app, loaded_rules, dataset_endpoint)
    init_request_handlers(app, redis)

    # Phase 2: complete state.
    publish_state(
        custom_filter=custom_filter,
        display_rules=loaded_rules,
        form_fields_cache=loaded_form_fields,
        dataset_is_quadstore=loaded_is_quadstore,
        shacl_graph=loaded_shacl_graph,
        classes_with_multiple_shapes=loaded_multi_shape_classes,
    )
    app.extensions["login_manager"] = login_manager
    app.extensions["redis_client"] = redis
def init_login_manager(app: Flask, login_manager: LoginManager) -> None:
    """Configure Flask-Login for *app*.

    - Unauthenticated access redirects to the ``auth.login`` endpoint.
    - Users are rebuilt from the stored id. The display name comes from
      ``session["user_name"]`` (default ``"Unknown User"``), and the id doubles
      as the ORCID.
    - When a user is restored from a remember-me cookie, the session is marked
      as modified, presumably so Flask re-issues a fresh session cookie.
    """
    login_manager.init_app(app)
    login_manager.login_view = "auth.login"  # type: ignore[reportAttributeAccessIssue]
    login_manager.unauthorized_handler(lambda: redirect(url_for("auth.login")))

    @login_manager.user_loader
    def load_user(user_id: str) -> User:
        user_name = session.get("user_name", "Unknown User")
        return User(user_id=user_id, name=user_name, orcid=user_id)

    # Flask-Login sends this signal as ``send(app, user=user)``. The user
    # therefore arrives as a *keyword* argument, which **_extra absorbs; a
    # positional ``_user`` parameter would raise TypeError on delivery.
    def rotate_session_token(_sender: object, **_extra: object) -> None:
        session.modified = True

    # blinker keeps only a weak reference by default (weak=True). This nested
    # function has no other strong reference, so it would be garbage-collected
    # as soon as init_login_manager returned, silently disconnecting it.
    # sender=app scopes the receiver to this application; blinker drops the
    # connection when the app object itself goes away.
    user_loaded_from_cookie.connect(rotate_session_token, sender=app, weak=False)
def initialize_change_tracking_config(
    app: Flask,
    adjusted_dataset_endpoint: str | None = None,
    adjusted_provenance_endpoint: str | None = None,
) -> dict:
    """Load the change-tracking configuration, generating it first if needed.

    How the path is resolved:

    * ``CHANGE_TRACKING_CONFIG`` not set: a file named
      ``change_tracking_config.json`` is generated in the instance folder, which
      is created if missing.
    * ``CHANGE_TRACKING_CONFIG`` set but the file is missing: a warning is
      logged and a new file is generated at that path.
    * Otherwise the existing JSON file is read.

    Generation receives the adjusted endpoint URLs (an empty list for any
    endpoint not given). Full-text search is enabled for the triplestore named
    in ``DATASET_DB_TRIPLESTORE`` when ``DATASET_DB_TEXT_INDEX_ENABLED`` is set.
    The resolved path is written back to ``app.config["CHANGE_TRACKING_CONFIG"]``.

    Returns:
        The configuration dictionary.

    Raises:
        RuntimeError: if generation fails with ``OSError``, or if the file
            cannot be read or is not valid JSON.
    """
    if "CHANGE_TRACKING_CONFIG" in app.config:
        config_path = app.config["CHANGE_TRACKING_CONFIG"]
        must_generate = not Path(config_path).exists()
        if must_generate:
            app.logger.warning(
                "Change tracking configuration file not found at specified path: %s",
                config_path,
            )
    else:
        instance_dir = Path(app.instance_path)
        config_path = str(instance_dir / "change_tracking_config.json")
        instance_dir.mkdir(parents=True, exist_ok=True)
        must_generate = True

    config = None
    if must_generate:
        triplestore = app.config.get("DATASET_DB_TRIPLESTORE", "").lower()
        text_index_on = app.config.get("DATASET_DB_TEXT_INDEX_ENABLED", False)
        try:
            config = generate_config_file(
                config_path=config_path,
                dataset_urls=(
                    [adjusted_dataset_endpoint] if adjusted_dataset_endpoint else []
                ),
                dataset_dirs=app.config.get("DATASET_DIRS", []),
                dataset_is_quadstore=app.config.get("DATASET_IS_QUADSTORE", False),
                provenance_urls=(
                    [adjusted_provenance_endpoint]
                    if adjusted_provenance_endpoint
                    else []
                ),
                provenance_is_quadstore=app.config.get(
                    "PROVENANCE_IS_QUADSTORE", False
                ),
                provenance_dirs=app.config.get("PROVENANCE_DIRS", []),
                blazegraph_full_text_search=(
                    triplestore == "blazegraph" and text_index_on
                ),
                fuseki_full_text_search=triplestore == "fuseki" and text_index_on,
                virtuoso_full_text_search=(
                    triplestore == "virtuoso" and text_index_on
                ),
                # TODO(@arcangelo-massari): Add graphdb support
                # https://github.com/opencitations/heritrace/issues/1
                graphdb_connector_name="",
            )
            app.logger.info(
                "Generated new change tracking configuration at: %s", config_path
            )
        except OSError as e:
            msg = f"Failed to generate change tracking configuration: {e!s}"
            raise RuntimeError(msg) from e

    if not config:
        try:
            with Path(config_path).open(encoding="utf8") as handle:
                config = json.load(handle)
        except json.JSONDecodeError as e:
            msg = f"Invalid change tracking configuration JSON at {config_path}: {e!s}"
            raise RuntimeError(msg) from e
        except OSError as e:
            msg = f"Error reading change tracking configuration at {config_path}: {e!s}"
            raise RuntimeError(msg) from e

    app.config["CHANGE_TRACKING_CONFIG"] = config_path
    return config
def need_initialization(app: Flask, redis: Redis) -> bool:
    """Decide whether the counters must be (re)seeded from the triplestores.

    If ``REDIS_URL`` differs from the default local instance
    (``redis://localhost:6379/0``), this logs that initialization is skipped
    and returns ``False``. Otherwise it returns ``True`` when any of these hold:

    - no ``heritrace:last_initialization`` timestamp is stored;
    - the timestamp cannot be read;
    - Redis raises an error;
    - the timestamp is older than ``CACHE_VALIDITY_DAYS`` days.

    Works with Redis clients that return ``bytes`` as well as ``str``. A naive
    stored timestamp is interpreted as UTC.
    """
    redis_url = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
    is_external_redis = redis_url != "redis://localhost:6379/0"

    if is_external_redis:
        app.logger.info(
            "Using external Redis at %s - skipping counter initialization", redis_url
        )
        return False

    cache_validity_days = app.config["CACHE_VALIDITY_DAYS"]

    try:
        last_init_raw: object = redis.get("heritrace:last_initialization")
    except RedisError:
        return True
    if not last_init_raw:
        return True

    try:
        # Clients created without decode_responses=True return bytes, which
        # datetime.fromisoformat() rejects with an (uncaught) TypeError.
        last_init_text = (
            last_init_raw.decode("utf-8")
            if isinstance(last_init_raw, bytes)
            else str(last_init_raw)
        )
        last_init = datetime.fromisoformat(last_init_text)
    except ValueError:  # also covers UnicodeDecodeError
        return True

    if last_init.tzinfo is None:
        # Subtracting a naive datetime from an aware "now" raises TypeError.
        # update_cache writes UTC timestamps, so assume UTC here.
        last_init = last_init.replace(tzinfo=timezone.utc)

    return datetime.now(tz=timezone.utc) - last_init > timedelta(
        days=cache_validity_days
    )
def update_cache(_app: Flask, redis: Redis) -> None:
    """Record in Redis that the counters were just initialised.

    Writes the current UTC time (ISO 8601) to ``heritrace:last_initialization``,
    then the version marker ``"1.0"`` to ``heritrace:cache_version``. The app
    argument is unused.
    """
    timestamp = datetime.now(tz=timezone.utc).isoformat()
    entries = (
        ("heritrace:last_initialization", timestamp),
        ("heritrace:cache_version", "1.0"),
    )
    for key, value in entries:
        redis.set(key, value)
def initialize_counter_handler(
    app: Flask,
    redis: Redis,
    sparql: SPARQLWrapperWithRetry,
    provenance_sparql: SPARQLWrapperWithRetry,
) -> None:
    """Seed URI and provenance-snapshot counters, unless the cache is fresh.

    Does nothing when ``need_initialization`` returns ``False``. Otherwise:

    1. If ``URI_GENERATOR`` is a ``CounterBasedURIGenerator``, its counters are
       initialised from the dataset endpoint.
    2. For every entity, the distinct ``prov:Entity`` snapshots that are
       ``prov:specializationOf`` that entity are counted on the provenance
       endpoint. Each count is stored through ``COUNTER_HANDLER.set_counter``.
    3. The initialisation timestamp is refreshed via ``update_cache``.
    """
    if not need_initialization(app, redis):
        return

    generator = app.config["URI_GENERATOR"]
    if isinstance(generator, CounterBasedURIGenerator):
        generator.initialize_counters(sparql)

    handler = app.config["COUNTER_HANDLER"]

    snapshot_count_query = """
        SELECT ?entity (COUNT(DISTINCT ?snapshot) as ?count)
        WHERE {
            ?snapshot a <http://www.w3.org/ns/prov#Entity> ;
                <http://www.w3.org/ns/prov#specializationOf> ?entity .
            OPTIONAL {
                ?snapshot <http://www.w3.org/ns/prov#wasDerivedFrom> ?prev .
            }
        }
        GROUP BY ?entity
    """

    provenance_sparql.setQuery(snapshot_count_query)
    provenance_sparql.setReturnFormat(JSON)
    response = provenance_sparql.query().convert()

    for binding in get_sparql_bindings(response):
        entity_iri = binding["entity"]["value"]
        snapshot_total = int(binding["count"]["value"])
        handler.set_counter(snapshot_total, entity_iri)

    update_cache(app, redis)
def identify_classes_with_multiple_shapes(
    display_rules: list[dict], shacl_graph: Graph
) -> set[str]:
    """Return class URIs that more than one visible SHACL shape targets.

    Only classes and shapes mentioned by display rules are considered:

    - a rule whose ``target`` has a ``class`` contributes every shape with
      ``sh:targetClass`` equal to that class;
    - a rule whose ``target`` has a ``shape`` contributes every class that shape
      targets via ``sh:targetClass``.

    A ``(class, shape)`` pair counts only if ``is_entity_type_visible`` accepts
    it. Returns an empty set when there are no rules or the graph is empty.
    """
    if not display_rules or not shacl_graph:
        return set()

    from rdflib import URIRef  # noqa: PLC0415
    from rdflib.namespace import SH  # noqa: PLC0415

    from heritrace.utils.display_rules_utils import (  # noqa: PLC0415
        is_entity_type_visible,
    )

    class_to_shapes: defaultdict[str, set[str]] = defaultdict(set)

    def record_if_visible(class_uri: str, shape_uri: str) -> None:
        # Shared by both rule kinds; previously duplicated in each branch.
        if is_entity_type_visible((class_uri, shape_uri)):
            class_to_shapes[class_uri].add(shape_uri)

    for rule in display_rules:
        # ``or {}`` also tolerates an explicit ``target: null`` in the rules.
        target = rule.get("target") or {}
        # Triple-pattern lookups replace f-string SPARQL. Config URIs are never
        # spliced into query text, so a value containing ``>`` or whitespace
        # cannot break parsing, and no SPARQL query is parsed per rule.
        if "class" in target:
            class_uri = target["class"]
            for shape in shacl_graph.subjects(SH.targetClass, URIRef(class_uri)):
                record_if_visible(class_uri, str(shape))
        elif "shape" in target:
            shape_uri = target["shape"]
            for target_class in shacl_graph.objects(
                URIRef(shape_uri), SH.targetClass
            ):
                record_if_visible(str(target_class), shape_uri)

    return {
        class_uri for class_uri, shapes in class_to_shapes.items() if len(shapes) > 1
    }
def _load_display_rules(app: Flask) -> list[dict]:
    """Return the ``rules`` list from the YAML file at ``DISPLAY_RULES_PATH``.

    Returns ``[]`` when the setting is absent, when the file does not exist
    (logged as a warning), or when ``rules`` is empty.

    Raises:
        RuntimeError: if the YAML is malformed or has no top-level ``rules``.
        OSError: if the file exists but cannot be opened.
    """
    configured_path = app.config.get("DISPLAY_RULES_PATH")
    if not configured_path:
        return []
    # Path() also accepts plain-string settings, which previously crashed on
    # ``.exists()``.
    rules_path = Path(configured_path)
    if not rules_path.exists():
        app.logger.warning("Display rules file not found at: %s", rules_path)
        return []
    try:
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with rules_path.open(encoding="utf-8") as f:
            yaml_content = yaml.safe_load(f)
    except yaml.YAMLError as e:
        app.logger.exception("Error loading display rules")
        msg = f"Failed to load display rules: {e!s}"
        raise RuntimeError(msg) from e
    # An empty file loads as None, and a file without ``rules`` used to escape
    # as a raw TypeError/KeyError. Report both as RuntimeError, like other
    # load failures.
    if not isinstance(yaml_content, dict) or "rules" not in yaml_content:
        msg = f"Failed to load display rules: no top-level 'rules' key in {rules_path}"
        raise RuntimeError(msg)
    return yaml_content["rules"] or []


def _load_shacl(app: Flask, display_rules: list[dict]) -> tuple[Graph, dict]:
    """Parse the Turtle file at ``SHACL_PATH`` and derive the form-field cache.

    Returns an empty graph and ``{}`` when the setting is absent or the file
    does not exist (logged as a warning).

    Raises:
        RuntimeError: if reading, parsing or form-field extraction fails.
    """
    shacl_graph = Graph()
    configured_path = app.config.get("SHACL_PATH")
    if not configured_path:
        return shacl_graph, {}
    shacl_path = Path(configured_path)
    if not shacl_path.exists():
        app.logger.warning("SHACL file not found at: %s", shacl_path)
        return shacl_graph, {}
    try:
        shacl_graph.parse(source=shacl_path, format="turtle")

        from heritrace.utils.shacl_utils import (  # noqa: PLC0415
            get_form_fields_from_shacl,
        )

        form_fields_cache = get_form_fields_from_shacl(
            shacl_graph, display_rules, app=app
        )
    # rdflib reports Turtle syntax errors as BadSyntax, a SyntaxError
    # subclass, which previously escaped the RuntimeError wrapping.
    except (OSError, ValueError, SyntaxError) as e:
        app.logger.exception("Error initializing form fields from SHACL")
        msg = f"Failed to initialize form fields: {e!s}"
        raise RuntimeError(msg) from e
    return shacl_graph, form_fields_cache


def initialize_global_variables(
    app: Flask,
) -> tuple[list[dict], dict, bool, Graph, set[str]]:
    """Load display rules, SHACL shapes and the values derived from them.

    Returns:
        ``(display_rules, form_fields_cache, dataset_is_quadstore, shacl_graph,
        classes_with_multiple_shapes)``.

    Raises:
        RuntimeError: for every load failure. ``OSError``, ``yaml.YAMLError``
            and ``ValueError`` are wrapped, and ``RuntimeError`` from the
            helpers is re-raised unchanged.
    """
    try:
        dataset_is_quadstore = app.config.get("DATASET_IS_QUADSTORE", False)
        display_rules = _load_display_rules(app)
        shacl_graph, form_fields_cache = _load_shacl(app, display_rules)
        classes_with_multiple_shapes = identify_classes_with_multiple_shapes(
            display_rules, shacl_graph
        )
        app.logger.info("Global variables initialized successfully")
    except RuntimeError:
        raise
    except (OSError, yaml.YAMLError, ValueError) as e:
        app.logger.exception("Error during global variables initialization")
        msg = f"Global variables initialization failed: {e!s}"
        raise RuntimeError(msg) from e
    else:
        return (
            display_rules,
            form_fields_cache,
            dataset_is_quadstore,
            shacl_graph,
            classes_with_multiple_shapes,
        )
def init_sparql_services(
    app: Flask,
) -> tuple[str, str, SPARQLWrapperWithRetry, SPARQLWrapperWithRetry, dict]:
    """Build the dataset and provenance SPARQL clients and load change tracking.

    The URLs in ``DATASET_DB_URL`` and ``PROVENANCE_DB_URL`` first go through
    ``adjust_endpoint_url``. Each client is created with ``timeout=30.0``
    (presumably seconds).

    Returns:
        ``(dataset_endpoint, provenance_endpoint, sparql, provenance_sparql,
        change_tracking_config)``.
    """
    dataset_url, provenance_url = (
        adjust_endpoint_url(app.config[setting])
        for setting in ("DATASET_DB_URL", "PROVENANCE_DB_URL")
    )
    dataset_client, provenance_client = (
        SPARQLWrapperWithRetry(endpoint, timeout=30.0)
        for endpoint in (dataset_url, provenance_url)
    )
    tracking_config = initialize_change_tracking_config(
        app,
        adjusted_dataset_endpoint=dataset_url,
        adjusted_provenance_endpoint=provenance_url,
    )
    return dataset_url, provenance_url, dataset_client, provenance_client, tracking_config
def init_filters(
    app: Flask, display_rules: list[dict], dataset_endpoint: str
) -> Filter:
    """Create the project's ``Filter`` and register its Jinja template filters.

    The filter is built from the ``@context`` of ``utils/context.json`` (next to
    this module), the display rules (``None`` when empty) and the dataset
    endpoint.

    Returns:
        The ``Filter`` instance whose methods were registered.
    """
    context_path = Path(__file__).parent / "utils" / "context.json"
    # JSON is UTF-8 by specification (RFC 8259); do not depend on the locale's
    # default encoding. This matches the ``encoding="utf8"`` used elsewhere in
    # this module.
    with context_path.open(encoding="utf-8") as config_file:
        context = json.load(config_file)["@context"]

    custom_filter = Filter(context, display_rules or None, dataset_endpoint)

    app.jinja_env.filters.update(
        {
            "human_readable_predicate": custom_filter.human_readable_predicate,
            "human_readable_class": custom_filter.human_readable_class,
            "human_readable_entity": custom_filter.human_readable_entity,
            "human_readable_primary_source": (
                custom_filter.human_readable_primary_source
            ),
            "format_datetime": custom_filter.human_readable_datetime,
            "split_ns": split_namespace,
            "format_source_reference": custom_filter.format_source_reference,
            "format_agent_reference": custom_filter.format_agent_reference,
        }
    )
    return custom_filter
def init_request_handlers(app: Flask, redis: Redis) -> None:
    """Attach a ``ResourceLockManager`` to ``flask.g`` for each request.

    Before each request, a manager bound to *redis* is stored as
    ``g.resource_lock_manager`` unless one is already present. On app-context
    teardown the attribute is removed again.
    """

    def initialize_lock_manager() -> None:
        if "resource_lock_manager" not in g:
            g.resource_lock_manager = ResourceLockManager(redis)

    def close_redis_connection(_error: BaseException | None) -> None:
        # Only drops the reference stored on ``g``; the Redis client passed in
        # is not closed here.
        g.pop("resource_lock_manager", None)

    app.before_request(initialize_lock_manager)
    app.teardown_appcontext(close_redis_connection)
def adjust_endpoint_url(url: str) -> str:
    """Rewrite a loopback endpoint URL so it is reachable from inside Docker.

    Outside Docker (per ``running_in_docker``), *url* is returned unchanged.
    Inside Docker, when the URL's host is exactly ``localhost``, ``127.0.0.1``
    or ``0.0.0.0`` (case-insensitive), the host is replaced by
    ``host.docker.internal``. Scheme, credentials, port, path, query and
    fragment are preserved. Any other URL is returned unchanged.

    Example (inside Docker)::

        http://localhost:8890/sparql -> http://host.docker.internal:8890/sparql
    """
    if not running_in_docker():
        return url

    local_hosts = {"localhost", "127.0.0.1", "0.0.0.0"}
    parsed_url = urlparse(url)

    # Strip optional ``user:password@`` first, so its colon is not mistaken for
    # the port separator. The old split(":") mangled such netlocs.
    userinfo, at_sign, hostport = parsed_url.netloc.rpartition("@")
    host, colon, port = hostport.partition(":")

    # Compare the whole host. Substring matching also rewrote hosts such as
    # ``mylocalhost.example.org`` or ``127.0.0.10``.
    if host.lower() not in local_hosts:
        return url

    new_netloc = f"{userinfo}{at_sign}host.docker.internal{colon}{port}"
    return urlunparse(parsed_url._replace(netloc=new_netloc))
def running_in_docker() -> bool:
    """Report whether the process appears to run inside a Docker container.

    The check is the presence of the ``/.dockerenv`` marker file.
    """
    docker_marker = Path("/") / ".dockerenv"
    return docker_marker.exists()
def get_dataset_endpoint() -> str:
    """Return the dataset SPARQL endpoint URL from the current app's state."""
    state = get_app_state()
    return state.dataset_endpoint
def get_sparql() -> SPARQLWrapperWithRetry:
    """Return the SPARQL client bound to the dataset endpoint."""
    state = get_app_state()
    return state.sparql
def get_provenance_endpoint() -> str:
    """Return the provenance SPARQL endpoint URL from the current app's state."""
    state = get_app_state()
    return state.provenance_endpoint
def get_provenance_sparql() -> SPARQLWrapperWithRetry:
    """Return the SPARQL client bound to the provenance endpoint."""
    state = get_app_state()
    return state.provenance_sparql
def get_counter_handler() -> CounterHandler:
    """Return the counter handler of the configured URI generator.

    Raises:
        TypeError: after logging an error, when ``URI_GENERATOR`` is missing
            or is not a ``CounterBasedURIGenerator``.
    """
    generator = current_app.config.get("URI_GENERATOR")
    if isinstance(generator, CounterBasedURIGenerator):
        return generator.counter_handler

    current_app.logger.error("CounterHandler not found in URIGenerator config.")
    msg = "CounterHandler is not available. Initialization might have failed."
    raise TypeError(msg)
def get_custom_filter() -> Filter:
    """Return the Jinja ``Filter`` helper stored in the current app's state."""
    state = get_app_state()
    return state.custom_filter
def get_change_tracking_config() -> dict:
    """Return the change-tracking configuration loaded at startup."""
    state = get_app_state()
    return state.change_tracking_config
def get_display_rules() -> list[dict]:
    """Return the display rules stored in the current app's state."""
    state = get_app_state()
    return state.display_rules
def get_form_fields() -> dict:
    """Return the cached SHACL-derived form-field definitions."""
    state = get_app_state()
    return state.form_fields_cache
def get_dataset_is_quadstore() -> bool:
    """Return the stored ``dataset_is_quadstore`` flag of the current app."""
    state = get_app_state()
    return state.dataset_is_quadstore
def get_shacl_graph() -> Graph:
    """Return the SHACL shapes graph stored in the current app's state."""
    state = get_app_state()
    return state.shacl_graph
def get_classes_with_multiple_shapes() -> set[str]:
    """Return the class URIs that have more than one visible SHACL shape."""
    state = get_app_state()
    return state.classes_with_multiple_shapes