Coverage for heritrace/extensions.py: 100%
298 statements
« prev ^ index » next — coverage.py v7.6.12, created at 2025-06-24 11:39 +0000
1# heritrace/extensions.py
3import json
4import logging
5import os
6import time
7from collections import defaultdict
8from datetime import datetime, timedelta
9from typing import Dict
10from urllib.parse import urlparse, urlunparse
12import yaml
13from flask import Flask, current_app, g, redirect, session, url_for
14from flask_babel import Babel
15from flask_login import LoginManager
16from flask_login.signals import user_loaded_from_cookie
17from heritrace.models import User
18from heritrace.services.resource_lock_manager import ResourceLockManager
19from heritrace.uri_generator.uri_generator import URIGenerator
20from heritrace.utils.filters import Filter
21from rdflib import Graph
22from rdflib_ocdm.counter_handler.counter_handler import CounterHandler
23from redis import Redis
24from SPARQLWrapper import JSON, SPARQLWrapper
25from time_agnostic_library.support import generate_config_file
# Global variables
# Module-level singletons shared across the application. They are populated
# once by init_extensions() / init_sparql_services() and read back through
# the get_*() accessor functions at the bottom of this module.
initialization_done = False  # guards init_sparql_services() against running twice
dataset_endpoint = None  # dataset SPARQL endpoint URL (Docker-adjusted)
provenance_endpoint = None  # provenance SPARQL endpoint URL (Docker-adjusted)
sparql = None  # SPARQLWrapperWithRetry bound to the dataset endpoint
provenance_sparql = None  # SPARQLWrapperWithRetry bound to the provenance endpoint
change_tracking_config = None  # dict loaded/generated by initialize_change_tracking_config()
form_fields_cache = None  # form fields derived from the SHACL graph (lazy, built once)
custom_filter = None  # Filter instance backing the Jinja template filters
redis_client = None  # shared Redis client, used by the resource lock manager
display_rules = None  # 'rules' list parsed from the display-rules YAML file
dataset_is_quadstore = None  # whether the dataset store is a quadstore (from config)
shacl_graph = None  # parsed SHACL shapes graph (rdflib Graph)
classes_with_multiple_shapes = None  # class URIs having more than one visible shape
class SPARQLWrapperWithRetry(SPARQLWrapper):
    """
    Extension of SPARQLWrapper that includes automatic retry functionality and timeout handling.

    Failed queries are retried up to ``max_attempts`` times with an
    exponentially growing delay between attempts. Uses SPARQLWrapper's
    built-in timeout functionality for each individual attempt.
    """

    def __init__(self, endpoint, **kwargs):
        """
        Args:
            endpoint: SPARQL endpoint URL.
            **kwargs: Forwarded to SPARQLWrapper after extracting:
                max_attempts (int): Maximum number of attempts (default 3, clamped to >= 1).
                initial_delay (float): Seconds before the first retry (default 1.0).
                backoff_factor (float): Delay multiplier after each retry (default 2.0).
                timeout (float): Per-attempt timeout in seconds (default 5.0).
        """
        # Clamp to at least one attempt so query() always either returns a
        # result or re-raises a real exception; with max_attempts < 1 the
        # original loop never ran and ended in `raise None` (a TypeError).
        self.max_attempts = max(1, kwargs.pop('max_attempts', 3))
        self.initial_delay = kwargs.pop('initial_delay', 1.0)
        self.backoff_factor = kwargs.pop('backoff_factor', 2.0)
        query_timeout = kwargs.pop('timeout', 5.0)

        super().__init__(endpoint, **kwargs)

        # SPARQLWrapper's setTimeout expects an integer number of seconds.
        self.setTimeout(int(query_timeout))

    def query(self):
        """
        Override the query method to include retry logic with SPARQLWrapper's built-in timeout.

        Returns the original SPARQLWrapper.QueryResult so that convert() can be called on it.

        Raises:
            Exception: The last exception encountered once all attempts have failed.
        """
        logger = logging.getLogger(__name__)

        delay = self.initial_delay
        last_exception = None

        for attempt in range(1, self.max_attempts + 1):
            try:
                return super().query()
            except Exception as e:
                last_exception = e
                logger.warning(f"SPARQL query attempt {attempt}/{self.max_attempts} failed: {str(e)}")

                if attempt < self.max_attempts:
                    logger.info(f"Retrying in {delay:.2f} seconds...")
                    time.sleep(delay)
                    delay *= self.backoff_factor

        logger.error(f"All {self.max_attempts} SPARQL query attempts failed")
        raise last_exception
def init_extensions(app: Flask, babel: Babel, login_manager: LoginManager, redis: Redis):
    """
    Wire up all Flask extensions and shared service objects on the app.

    Args:
        app: Flask application instance
        babel: Babel extension instance
        login_manager: LoginManager instance
        redis: Redis client instance
    """
    global redis_client

    # Expose the Redis client through the module-level singleton.
    redis_client = redis

    # Babel: per-session language selection plus the configured
    # translation directories.
    babel.init_app(
        app=app,
        locale_selector=lambda: session.get('lang', 'en'),
        default_translation_directories=app.config['BABEL_TRANSLATION_DIRECTORIES']
    )

    init_login_manager(app, login_manager)  # Flask-Login setup
    init_sparql_services(app)               # SPARQL endpoints and related config
    init_filters(app)                       # Jinja template filters
    init_request_handlers(app)              # per-request hooks

    # Keep references on the app object for later retrieval.
    app.babel = babel
    app.login_manager = login_manager
    app.redis_client = redis_client
def init_login_manager(app, login_manager: LoginManager):
    """Set up Flask-Login: user loading, login view and session cookie rotation."""
    login_manager.init_app(app)
    login_manager.login_view = 'auth.login'
    login_manager.unauthorized_handler(lambda: redirect(url_for('auth.login')))

    @login_manager.user_loader
    def load_user(user_id):
        # The ORCID doubles as the user id; the display name lives in the session.
        return User(
            id=user_id,
            name=session.get('user_name', 'Unknown User'),
            orcid=user_id,
        )

    @user_loaded_from_cookie.connect
    def rotate_session_token(sender, user):
        # Mark the session dirty so Flask re-issues the session cookie.
        session.modified = True
def initialize_change_tracking_config(app: Flask, adjusted_dataset_endpoint=None, adjusted_provenance_endpoint=None):
    """
    Initialize and return the change tracking configuration JSON.

    Uses pre-adjusted endpoints if provided to avoid redundant adjustments.
    When no configuration file exists (or none is configured), a fresh one is
    generated with time_agnostic_library's generate_config_file().

    Args:
        app: Flask application instance
        adjusted_dataset_endpoint: Dataset endpoint URL already adjusted for Docker
        adjusted_provenance_endpoint: Provenance endpoint URL already adjusted for Docker

    Returns:
        dict: The loaded configuration dictionary

    Raises:
        RuntimeError: If the configuration cannot be generated, parsed or read.
    """
    config_needs_generation = False
    config_path = None
    config = None

    # Check if we have a config path in app.config
    if 'CHANGE_TRACKING_CONFIG' in app.config:
        config_path = app.config['CHANGE_TRACKING_CONFIG']
        if not os.path.exists(config_path):
            app.logger.warning(f"Change tracking configuration file not found at specified path: {config_path}")
            config_needs_generation = True
    else:
        # No path configured: generate into the Flask instance folder.
        config_needs_generation = True
        config_path = os.path.join(app.instance_path, 'change_tracking_config.json')
        os.makedirs(app.instance_path, exist_ok=True)

    if config_needs_generation:
        dataset_urls = [adjusted_dataset_endpoint] if adjusted_dataset_endpoint else []
        provenance_urls = [adjusted_provenance_endpoint] if adjusted_provenance_endpoint else []

        cache_endpoint = adjust_endpoint_url(app.config.get('CACHE_ENDPOINT', ''))
        cache_update_endpoint = adjust_endpoint_url(app.config.get('CACHE_UPDATE_ENDPOINT', ''))

        db_triplestore = app.config.get('DATASET_DB_TRIPLESTORE', '').lower()
        text_index_enabled = app.config.get('DATASET_DB_TEXT_INDEX_ENABLED', False)

        # Full-text search support is triplestore-specific and only enabled
        # when the text index is active for that store.
        blazegraph_search = db_triplestore == 'blazegraph' and text_index_enabled
        fuseki_search = db_triplestore == 'fuseki' and text_index_enabled
        virtuoso_search = db_triplestore == 'virtuoso' and text_index_enabled

        graphdb_connector = ''  # TODO: Add graphdb support

        try:
            config = generate_config_file(
                config_path=config_path,
                dataset_urls=dataset_urls,
                dataset_dirs=app.config.get('DATASET_DIRS', []),
                dataset_is_quadstore=app.config.get('DATASET_IS_QUADSTORE', False),
                provenance_urls=provenance_urls,
                provenance_is_quadstore=app.config.get('PROVENANCE_IS_QUADSTORE', False),
                provenance_dirs=app.config.get('PROVENANCE_DIRS', []),
                blazegraph_full_text_search=blazegraph_search,
                fuseki_full_text_search=fuseki_search,
                virtuoso_full_text_search=virtuoso_search,
                graphdb_connector_name=graphdb_connector,
                cache_endpoint=cache_endpoint,
                cache_update_endpoint=cache_update_endpoint
            )
            app.logger.info(f"Generated new change tracking configuration at: {config_path}")
        except Exception as e:
            # Chain the original exception so the traceback is preserved.
            raise RuntimeError(f"Failed to generate change tracking configuration: {str(e)}") from e

    # Load and validate the configuration
    try:
        if not config:
            with open(config_path, 'r', encoding='utf8') as f:
                config = json.load(f)

        # Adjust cache URLs if needed (e.g. when running inside Docker)
        if config['cache_triplestore_url'].get('endpoint'):
            config['cache_triplestore_url']['endpoint'] = adjust_endpoint_url(
                config['cache_triplestore_url']['endpoint']
            )

        if config['cache_triplestore_url'].get('update_endpoint'):
            config['cache_triplestore_url']['update_endpoint'] = adjust_endpoint_url(
                config['cache_triplestore_url']['update_endpoint']
            )

    except json.JSONDecodeError as e:
        raise RuntimeError(f"Invalid change tracking configuration JSON at {config_path}: {str(e)}") from e
    except Exception as e:
        raise RuntimeError(f"Error reading change tracking configuration at {config_path}: {str(e)}") from e

    app.config['CHANGE_TRACKING_CONFIG'] = config_path
    return config
def need_initialization(app: Flask):
    """
    Check if counter handler initialization is needed.

    Returns True when the URI generator owns a counter handler and the cache
    file is missing, unreadable or older than the configured validity window.
    """
    generator = app.config['URI_GENERATOR']

    # Generators without a counter handler have nothing to initialize.
    if not hasattr(generator, "counter_handler"):
        return False

    cache_path = app.config['CACHE_FILE']
    validity_days = app.config['CACHE_VALIDITY_DAYS']

    if not os.path.exists(cache_path):
        return True

    try:
        with open(cache_path, 'r', encoding='utf8') as cache_file:
            stamp = json.load(cache_file)['last_initialization']
        age = datetime.now() - datetime.fromisoformat(stamp)
        return age > timedelta(days=validity_days)
    except Exception:
        # Corrupt or incomplete cache: treat as stale.
        return True
def update_cache(app: Flask):
    """
    Update the cache file with current initialization timestamp.

    Writes a small JSON document ({last_initialization, version}) to the
    configured CACHE_FILE so need_initialization() can check staleness later.
    """
    payload = {
        'last_initialization': datetime.now().isoformat(),
        'version': '1.0'
    }
    with open(app.config['CACHE_FILE'], 'w', encoding='utf8') as cache_file:
        json.dump(payload, cache_file, ensure_ascii=False, indent=4)
def initialize_counter_handler(app: Flask):
    """
    Initialize the counter handler for URI generation if needed.

    Seeds the URI generator's own counters from the dataset, then derives a
    per-entity snapshot count from the provenance store and records it in the
    counter handler. Skipped entirely while the cache is still fresh.
    """
    if not need_initialization(app):
        return

    uri_generator: URIGenerator = app.config['URI_GENERATOR']
    counter_handler: CounterHandler = uri_generator.counter_handler

    # Let the URI generator seed its own counters from the dataset.
    uri_generator.initialize_counters(sparql)

    # Count snapshots per entity in the provenance store. Every snapshot but
    # the first carries a wasDerivedFrom link, hence the OPTIONAL clause.
    prov_query = """
    SELECT ?entity (COUNT(DISTINCT ?snapshot) as ?count)
    WHERE {
        ?snapshot a <http://www.w3.org/ns/prov#Entity> ;
                  <http://www.w3.org/ns/prov#specializationOf> ?entity .
        OPTIONAL {
            ?snapshot <http://www.w3.org/ns/prov#wasDerivedFrom> ?prev .
        }
    }
    GROUP BY ?entity
    """

    # Run the query against the provenance endpoint and store the counters.
    provenance_sparql.setQuery(prov_query)
    provenance_sparql.setReturnFormat(JSON)
    prov_results = provenance_sparql.query().convert()

    for binding in prov_results["results"]["bindings"]:
        snapshot_count = int(binding["count"]["value"])
        counter_handler.set_counter(snapshot_count, binding["entity"]["value"])

    update_cache(app)
def identify_classes_with_multiple_shapes():
    """
    Identify classes that have multiple VISIBLE shapes associated with them.

    Only returns classes where multiple shapes are actually visible to avoid
    unnecessary processing.

    Returns:
        Set[str]: Set of class URIs that have multiple visible shapes
    """
    global display_rules, shacl_graph

    if not display_rules or not shacl_graph:
        return set()

    from heritrace.utils.display_rules_utils import is_entity_type_visible

    shapes_by_class = defaultdict(set)

    def record_if_visible(class_uri, shape_uri):
        # Track only (class, shape) pairs the display rules make visible.
        if is_entity_type_visible((class_uri, shape_uri)):
            shapes_by_class[class_uri].add(shape_uri)

    for rule in display_rules:
        target = rule.get("target", {})

        if "class" in target:
            # Rule targets a class: look up all shapes targeting it in SHACL.
            class_uri = target["class"]
            if shacl_graph:
                query_string = f"""
                SELECT DISTINCT ?shape WHERE {{
                    ?shape <http://www.w3.org/ns/shacl#targetClass> <{class_uri}> .
                }}
                """
                for row in shacl_graph.query(query_string):
                    record_if_visible(class_uri, str(row.shape))

        elif "shape" in target:
            # Rule targets a shape: look up the classes that shape targets.
            shape_uri = target["shape"]
            if shacl_graph:
                query_string = f"""
                SELECT DISTINCT ?class WHERE {{
                    <{shape_uri}> <http://www.w3.org/ns/shacl#targetClass> ?class .
                }}
                """
                for row in shacl_graph.query(query_string):
                    record_if_visible(str(row[0]), shape_uri)

    return {cls for cls, shapes in shapes_by_class.items() if len(shapes) > 1}
def initialize_global_variables(app: Flask):
    """
    Initialize all global variables including form fields cache, display rules,
    and dataset configuration from SHACL shapes graph and configuration files.

    Args:
        app: Flask application instance

    Raises:
        RuntimeError: If the display rules, the SHACL graph or the form
            fields cache cannot be loaded. Inner RuntimeErrors are caught by
            the outer handler and re-wrapped, so messages may be nested.
    """
    global shacl_graph, form_fields_cache, display_rules, dataset_is_quadstore, classes_with_multiple_shapes

    try:
        dataset_is_quadstore = app.config.get('DATASET_IS_QUADSTORE', False)

        # Display rules are optional: a missing file only logs a warning,
        # but a file that exists and fails to parse is a hard error.
        if app.config.get('DISPLAY_RULES_PATH'):
            if not os.path.exists(app.config['DISPLAY_RULES_PATH']):
                app.logger.warning(f"Display rules file not found at: {app.config['DISPLAY_RULES_PATH']}")
            else:
                try:
                    with open(app.config['DISPLAY_RULES_PATH'], 'r') as f:
                        yaml_content = yaml.safe_load(f)
                        display_rules = yaml_content['rules']
                except Exception as e:
                    app.logger.error(f"Error loading display rules: {str(e)}")
                    raise RuntimeError(f"Failed to load display rules: {str(e)}")

        if app.config.get('SHACL_PATH'):
            if not os.path.exists(app.config['SHACL_PATH']):
                app.logger.warning(f"SHACL file not found at: {app.config['SHACL_PATH']}")
                # NOTE(review): this early return also skips the
                # classes_with_multiple_shapes computation below — confirm intended.
                return

            # Form fields are derived from SHACL parsing, which is expensive;
            # skip if the cache was already built.
            if form_fields_cache is not None:
                return

            try:
                shacl_graph = Graph()
                shacl_graph.parse(source=app.config['SHACL_PATH'], format="turtle")
                # Imported lazily to avoid a circular import at module load time.
                from heritrace.utils.shacl_utils import \
                    get_form_fields_from_shacl
                form_fields_cache = get_form_fields_from_shacl(shacl_graph, display_rules, app=app)
            except Exception as e:
                app.logger.error(f"Error initializing form fields from SHACL: {str(e)}")
                raise RuntimeError(f"Failed to initialize form fields: {str(e)}")

        # Depends on both display_rules and shacl_graph being set above.
        classes_with_multiple_shapes = identify_classes_with_multiple_shapes()

        app.logger.info("Global variables initialized successfully")

    except Exception as e:
        app.logger.error(f"Error during global variables initialization: {str(e)}")
        raise RuntimeError(f"Global variables initialization failed: {str(e)}")
def init_sparql_services(app: Flask):
    """Initialize SPARQL endpoints and related services (runs at most once)."""
    global initialization_done, dataset_endpoint, provenance_endpoint, sparql, provenance_sparql, change_tracking_config

    if initialization_done:
        return

    # Resolve endpoint URLs once, including Docker host rewriting.
    dataset_endpoint = adjust_endpoint_url(app.config['DATASET_DB_URL'])
    provenance_endpoint = adjust_endpoint_url(app.config['PROVENANCE_DB_URL'])

    # Retrying wrappers around both endpoints.
    sparql = SPARQLWrapperWithRetry(dataset_endpoint)
    provenance_sparql = SPARQLWrapperWithRetry(provenance_endpoint)

    change_tracking_config = initialize_change_tracking_config(
        app,
        adjusted_dataset_endpoint=dataset_endpoint,
        adjusted_provenance_endpoint=provenance_endpoint
    )

    initialize_counter_handler(app)
    initialize_global_variables(app)
    initialization_done = True
def init_filters(app: Flask):
    """Initialize custom template filters."""
    global custom_filter

    # JSON-LD context used to shorten/expand URIs in the templates.
    with open(os.path.join("resources", "context.json"), "r") as config_file:
        context = json.load(config_file)["@context"]

    display_rules = None
    if app.config["DISPLAY_RULES_PATH"]:
        with open(app.config["DISPLAY_RULES_PATH"], 'r') as f:
            display_rules = yaml.safe_load(f).get('rules', [])

    custom_filter = Filter(context, display_rules, dataset_endpoint)

    # Register every human-readable formatter on the Jinja environment.
    app.jinja_env.filters.update({
        'human_readable_predicate': custom_filter.human_readable_predicate,
        'human_readable_class': custom_filter.human_readable_class,
        'human_readable_entity': custom_filter.human_readable_entity,
        'human_readable_primary_source': custom_filter.human_readable_primary_source,
        'format_datetime': custom_filter.human_readable_datetime,
        'split_ns': custom_filter.split_ns,
        'format_source_reference': custom_filter.format_source_reference,
        'format_agent_reference': custom_filter.format_agent_reference,
    })
def init_request_handlers(app):
    """Register per-request setup and app-context teardown hooks."""

    @app.before_request
    def initialize_lock_manager():
        """Lazily attach a ResourceLockManager to the request context."""
        if hasattr(g, 'resource_lock_manager'):
            return
        g.resource_lock_manager = ResourceLockManager(redis_client)

    @app.teardown_appcontext
    def close_redis_connection(error):
        """Drop the lock manager reference when the app context ends."""
        if hasattr(g, 'resource_lock_manager'):
            del g.resource_lock_manager
def adjust_endpoint_url(url: str) -> str:
    """
    Adjust endpoint URLs to work properly within Docker containers.

    Loopback-style hosts are rewritten to host.docker.internal when the app
    is running inside Docker; otherwise the URL is returned untouched.

    Args:
        url: The endpoint URL to adjust

    Returns:
        The adjusted URL if running in Docker, original URL otherwise
    """
    if not running_in_docker():
        return url

    parsed = urlparse(url)
    loopback_hosts = ('localhost', '127.0.0.1', '0.0.0.0')

    if not any(host in parsed.netloc for host in loopback_hosts):
        return url

    # Preserve the port (if any) while swapping in the Docker host alias.
    host_and_port = parsed.netloc.split(':')
    if len(host_and_port) > 1:
        docker_netloc = f'host.docker.internal:{host_and_port[1]}'
    else:
        docker_netloc = 'host.docker.internal'

    parts = list(parsed)
    parts[1] = docker_netloc
    return urlunparse(parts)
def running_in_docker() -> bool:
    """Check if the application is running inside a Docker container."""
    # Docker creates /.dockerenv at the container filesystem root.
    return os.path.exists('/.dockerenv')
def get_dataset_endpoint() -> str:
    """Return the configured dataset endpoint URL (module-level singleton)."""
    # `global` is unnecessary for a read-only access.
    return dataset_endpoint
def get_sparql() -> SPARQLWrapperWithRetry:
    """Return the retrying SPARQL wrapper bound to the dataset endpoint."""
    return sparql
def get_provenance_endpoint() -> str:
    """Return the configured provenance endpoint URL (module-level singleton)."""
    return provenance_endpoint
def get_provenance_sparql() -> SPARQLWrapperWithRetry:
    """Return the retrying SPARQL wrapper bound to the provenance endpoint."""
    return provenance_sparql
def get_counter_handler() -> CounterHandler:
    """
    Get the configured CounterHandler instance from the URIGenerator.

    Assumes URIGenerator and its counter_handler are initialized in app.config.

    Raises:
        RuntimeError: When no URIGenerator with a counter_handler is configured.
    """
    uri_generator: URIGenerator = current_app.config.get('URI_GENERATOR')

    # Fail loudly when initialization has not happened or is misconfigured.
    if uri_generator is None or not hasattr(uri_generator, 'counter_handler'):
        current_app.logger.error("CounterHandler not found in URIGenerator config.")
        raise RuntimeError("CounterHandler is not available. Initialization might have failed.")

    return uri_generator.counter_handler
def get_custom_filter() -> Filter:
    """Return the shared template Filter instance built by init_filters()."""
    return custom_filter
def get_change_tracking_config() -> Dict:
    """Return the loaded change tracking configuration dictionary."""
    return change_tracking_config
def get_display_rules() -> Dict:
    """Return the display rules loaded from the YAML configuration."""
    return display_rules
def get_form_fields() -> Dict:
    """Return the form fields cache derived from the SHACL graph."""
    return form_fields_cache
def get_dataset_is_quadstore() -> bool:
    """Return whether the dataset store is configured as a quadstore."""
    return dataset_is_quadstore
def get_shacl_graph() -> Graph:
    """Return the parsed SHACL shapes graph."""
    return shacl_graph
def get_classes_with_multiple_shapes() -> set:
    """Return the class URIs with multiple visible shapes (empty set if unset)."""
    if classes_with_multiple_shapes:
        return classes_with_multiple_shapes
    return set()