import random
import string
import time
from abc import ABC, abstractmethod
from datetime import datetime
from math import floor, log10

import orjson
from django.conf import settings
from siphashc import siphash

from core import r
from core.db.processing import annotate_results
from core.util import logs


def remove_defaults(query_params):
    """
    Delete, in place, every entry of query_params whose value equals the
    default configured for that field in settings.DRILLDOWN_DEFAULT_PARAMS.
    """
    defaults = settings.DRILLDOWN_DEFAULT_PARAMS
    # Iterate over a snapshot because entries are deleted while looping.
    for field, value in list(query_params.items()):
        if field in defaults and value == defaults[field]:
            del query_params[field]


def add_defaults(query_params):
    """
    Fill in, in place, every field of settings.DRILLDOWN_DEFAULT_PARAMS that
    is missing from query_params. Existing values are never overwritten.
    """
    for field, value in settings.DRILLDOWN_DEFAULT_PARAMS.items():
        if field not in query_params:
            query_params[field] = value


def dedup_list(data, check_keys):
    """
    Remove duplicate dictionaries from list.

    Two items are duplicates when they agree on every key of check_keys that
    is present in the item. Each run of dropped duplicates is replaced by a
    single marker ``{"type": "control", "hidden": <count>}`` placed where the
    run occurred.
    """
    seen = set()
    out = []

    dup_count = 0
    for item in data:
        dedup_key = tuple(item[key] for key in check_keys if key in item)
        if dedup_key in seen:
            dup_count += 1
            continue
        # A new, unseen item ends the current run of duplicates.
        if dup_count > 0:
            out.append({"type": "control", "hidden": dup_count})
            dup_count = 0
        out.append(item)
        seen.add(dedup_key)
    # Flush a run of duplicates that reaches the end of the input.
    if dup_count > 0:
        out.append({"type": "control", "hidden": dup_count})
    return out


def _round_sig(value, figures=3):
    """
    Round value to the given number of significant figures.

    Zero is returned unchanged because log10(0) is a math domain error.
    """
    if value == 0:
        return value
    return round(value, figures - int(floor(log10(abs(value)))) - 1)


class QueryError(Exception):
    """Raised by the parse_* helpers when called with raise_error=True."""


class StorageBackend(ABC):
    """
    Abstract base for search storage backends.

    Subclasses implement initialise, construct_query, run_query,
    query_results and parse. Most parse_* helpers return either the parsed
    value or, on failure, a dict with "message" and "class" keys.
    """

    def __init__(self, name):
        """Set up logging and the cache hash key, then call initialise()."""
        self.log = logs.get_logger(name)
        self.log.info(f"Initialising storage backend {name}")

        self.initialise_caching()
        self.initialise()

    @abstractmethod
    def initialise(self, **kwargs):
        pass

    def initialise_caching(self):
        """
        Load the key used to hash queries from `r` under "cache_hash_key",
        creating and storing a random 16-letter key if none exists yet.
        The result is kept on self.hash_key.
        """
        hash_key = r.get("cache_hash_key")
        if not hash_key:
            letters = string.ascii_lowercase
            hash_key = "".join(random.choice(letters) for _ in range(16))
            self.log.debug(f"Created new hash key: {hash_key}")
            r.set("cache_hash_key", hash_key)
        else:
            # The stored value comes back as bytes.
            hash_key = hash_key.decode("ascii")
            self.log.debug(f"Decoded hash key: {hash_key}")
        self.hash_key = hash_key

    @abstractmethod
    def construct_query(self, **kwargs):
        pass

    @abstractmethod
    def run_query(self, **kwargs):
        pass

    def parse_size(self, query_params, sizes):
        """
        Return the requested result size as an int (15 when "size" is not
        given), or an error dict when the value is not one of `sizes`.
        """
        if "size" in query_params:
            size = query_params["size"]
            if size not in sizes:
                message = "Size is not permitted"
                message_class = "danger"
                return {"message": message, "class": message_class}
            size = int(size)
        else:
            size = 15
        return size

    def parse_index(self, user, query_params, raise_error=False):
        """
        Map the "index" parameter to the configured index setting.

        "main" (also the default when no index is given) needs no permission.
        Any other index requires the ``core.index_<name>`` permission, and
        "restricted" additionally requires ``core.restricted_sources``.

        Returns the index setting, or an error dict. With raise_error=True a
        QueryError is raised instead of returning the error dict.
        """
        if "index" in query_params:
            index = query_params["index"]
            if index == "main":
                index = settings.INDEX_MAIN
            else:
                if not user.has_perm(f"core.index_{index}"):
                    message = f"Not permitted to search by this index: {index}"
                    if raise_error:
                        raise QueryError(message)
                    message_class = "danger"
                    return {
                        "message": message,
                        "class": message_class,
                    }
                if index == "meta":
                    index = settings.INDEX_META
                elif index == "internal":
                    index = settings.INDEX_INT
                elif index == "restricted":
                    if not user.has_perm("core.restricted_sources"):
                        message = f"Not permitted to search by this index: {index}"
                        if raise_error:
                            raise QueryError(message)
                        message_class = "danger"
                        return {
                            "message": message,
                            "class": message_class,
                        }
                    index = settings.INDEX_RESTRICTED
                else:
                    message = f"Index is not valid: {index}"
                    if raise_error:
                        raise QueryError(message)
                    message_class = "danger"
                    return {
                        "message": message,
                        "class": message_class,
                    }
        else:
            index = settings.INDEX_MAIN
        return index

    def parse_query(self, query_params, tags, size, custom_query, add_bool, **kwargs):
        """
        Build the search query from "query", custom_query or a blank query.

        Every name/value pair found in `tags` is appended to add_bool (which
        is mutated in place). Returns the search query, or the error dict
        produced by check_valid_query.
        """
        query_created = False
        if "query" in query_params:
            query = query_params["query"]
            search_query = self.construct_query(query, size, **kwargs)
            query_created = True
        else:
            if custom_query:
                search_query = custom_query
            else:
                search_query = self.construct_query(None, size, blank=True, **kwargs)

        if tags:
            # Get a blank search query
            if not query_created:
                search_query = self.construct_query(None, size, blank=True, **kwargs)
                query_created = True
            for item in tags:
                for tagname, tagvalue in item.items():
                    add_bool.append({tagname: tagvalue})

        valid = self.check_valid_query(query_params, custom_query)
        if isinstance(valid, dict):
            return valid

        return search_query

    def check_valid_query(self, query_params, custom_query):
        """
        Return an "Empty query!" warning dict when neither "query" nor "tags"
        is in query_params and no custom_query was supplied; otherwise None.
        """
        required_any = ["query", "tags"]
        if not any(field in query_params for field in required_any):
            if not custom_query:
                message = "Empty query!"
                message_class = "warning"
                return {"message": message, "class": message_class}
        return None

    def parse_source(self, user, query_params, raise_error=False):
        """
        Resolve the list of sources to search.

        A specific "source" yields a one-element list after validation and
        permission checks. No source, or "all", yields every main source plus
        the restricted ones when the user holds ``core.restricted_sources``.
        The literal "all" is never part of the returned list.

        Returns the list, or an error dict. With raise_error=True a
        QueryError is raised instead of returning the error dict.
        """
        source = None
        if "source" in query_params:
            source = query_params["source"]

            if source in settings.SOURCES_RESTRICTED:
                if not user.has_perm("core.restricted_sources"):
                    message = f"Access denied: {source}"
                    if raise_error:
                        raise QueryError(message)
                    message_class = "danger"
                    return {"message": message, "class": message_class}
            elif source not in settings.MAIN_SOURCES:
                message = f"Invalid source: {source}"
                if raise_error:
                    raise QueryError(message)
                message_class = "danger"
                return {"message": message, "class": message_class}

            if source == "all":
                source = None  # the next block will populate it

        if source:
            sources = [source]
        else:
            sources = list(settings.MAIN_SOURCES)
            if user.has_perm("core.restricted_sources"):
                sources.extend(settings.SOURCES_RESTRICTED)

        if "all" in sources:
            sources.remove("all")

        return sources

    def parse_sort(self, query_params):
        """
        Translate "sorting" ("asc", "desc" or "none") into "ascending",
        "descending" or None. Any other value yields an error dict.
        """
        sort = None
        if "sorting" in query_params:
            sorting = query_params["sorting"]
            if sorting not in ("asc", "desc", "none"):
                message = "Invalid sort"
                message_class = "danger"
                return {"message": message, "class": message_class}
            if sorting == "asc":
                sort = "ascending"
            elif sorting == "desc":
                sort = "descending"
        return sort

    def parse_date_time(self, query_params):
        """
        Return (from_ts, to_ts) as datetimes when all four of from_date,
        to_date, from_time and to_time are present, else (None, None).

        Values must match "%Y-%m-%d" and "%H:%M"; strptime raises ValueError
        otherwise.
        """
        required = {"from_date", "to_date", "from_time", "to_time"}
        if required.issubset(query_params.keys()):
            from_ts = f"{query_params['from_date']}T{query_params['from_time']}Z"
            to_ts = f"{query_params['to_date']}T{query_params['to_time']}Z"
            from_ts = datetime.strptime(from_ts, "%Y-%m-%dT%H:%MZ")
            to_ts = datetime.strptime(to_ts, "%Y-%m-%dT%H:%MZ")

            return (from_ts, to_ts)
        return (None, None)

    def parse_sentiment(self, query_params):
        """
        Return (sentiment_method, sentiment) from the query parameters.

        Both are None when "check_sentiment" is absent. When it is present,
        "sentiment_method" is mandatory and "sentiment", if given, must be
        convertible to float; otherwise an error dict is returned.
        """
        sentiment = None
        # Initialised up front so the final return cannot hit an unbound local
        # when "check_sentiment" is not requested.
        sentiment_method = None
        if "check_sentiment" in query_params:
            if "sentiment_method" not in query_params:
                message = "No sentiment method"
                message_class = "danger"
                return {"message": message, "class": message_class}
            if "sentiment" in query_params:
                sentiment = query_params["sentiment"]
                try:
                    sentiment = float(sentiment)
                except ValueError:
                    message = "Sentiment is not a float"
                    message_class = "danger"
                    return {"message": message, "class": message_class}
            sentiment_method = query_params["sentiment_method"]

        return (sentiment_method, sentiment)

    def filter_blacklisted(self, user, response):
        """
        Low level filter to take the raw search response and remove
        objects from it we want to keep secret.
        The response is mutated in place. Nothing is returned on the normal
        path; False is returned early if a hit has neither a "_source" nor a
        "fields" key.
        """
        # NOTE(review): the types come from settings.ELASTICSEARCH_BLACKLISTED
        # but the values from settings.BLACKLISTED - confirm both are intended.
        response["redacted"] = 0
        response["exemption"] = None
        if user.is_superuser:
            response["exemption"] = True
        # is_anonymous = isinstance(user, AnonymousUser)
        # For every hit from ES
        for index, item in enumerate(list(response["hits"]["hits"])):
            # For every blacklisted type
            for blacklisted_type in settings.ELASTICSEARCH_BLACKLISTED.keys():
                # Check this field we are matching exists
                if "_source" in item.keys():
                    data_index = "_source"
                elif "fields" in item.keys():
                    data_index = "fields"
                else:
                    return False
                if blacklisted_type in item[data_index].keys():
                    content = item[data_index][blacklisted_type]
                    # For every item in the blacklisted array for the type
                    for blacklisted_item in settings.BLACKLISTED[blacklisted_type]:
                        if blacklisted_item == str(content):
                            # Remove the item
                            if item in response["hits"]["hits"]:
                                # Let the UI know something was redacted
                                if (
                                    "exemption"
                                    not in response["hits"]["hits"][index][data_index]
                                ):
                                    response["redacted"] += 1
                                # Anonymous
                                if user.is_anonymous:
                                    # Just set it to none so the index is not off
                                    response["hits"]["hits"][index] = None
                                else:
                                    if not user.has_perm("core.bypass_blacklist"):
                                        response["hits"]["hits"][index] = None
                                    else:
                                        response["hits"]["hits"][index][data_index][
                                            "exemption"
                                        ] = True

        # Actually get rid of all the things we set to None
        response["hits"]["hits"] = [hit for hit in response["hits"]["hits"] if hit]

    def query(self, user, search_query, **kwargs):
        """
        Run search_query for user, using the per-user query cache when
        settings.CACHE is enabled.

        Returns a dict with "object_list" and "took" (milliseconds, rounded to
        three significant figures), plus "cache": True on a cache hit.
        Backend errors and empty result sets are returned as a dict with
        "message" and "class"; None is returned when the response carries a
        "took" of None.
        """
        # For time tracking
        # NOTE(review): process_time() counts CPU time only, so time spent
        # waiting on the backend is not included - confirm this is intended.
        start = time.process_time()
        cache_key = None
        if settings.CACHE:
            # Sort the keys so the hash is the same
            query_normalised = orjson.dumps(search_query, option=orjson.OPT_SORT_KEYS)
            query_hash = siphash(self.hash_key, query_normalised)
            cache_key = f"query_cache.{user.id}.{query_hash}"
            cache_hit = r.get(cache_key)
            if cache_hit:
                response = orjson.loads(cache_hit)
                time_took = (time.process_time() - start) * 1000
                return {
                    "object_list": response,
                    "took": _round_sig(time_took),
                    "cache": True,
                }
        response = self.run_query(user, search_query, **kwargs)

        # For Elasticsearch
        if isinstance(response, Exception):
            message = f"Error: {response.info['error']['root_cause'][0]['type']}"
            message_class = "danger"
            return {"message": message, "class": message_class}
        if len(response["hits"]["hits"]) == 0:
            message = "No results."
            message_class = "danger"
            return {"message": message, "class": message_class}

        # For Druid
        if "error" in response:
            if "errorMessage" in response:
                context = {
                    "message": response["errorMessage"],
                    "class": "danger",
                }
                return context
            else:
                return response
        if "took" in response:
            if response["took"] is None:
                return None

        # Removed for now, no point given we have restricted indexes
        # self.filter_blacklisted(user, response)

        # Parse the response
        response_parsed = self.parse(response)

        # Write cache
        if settings.CACHE:
            to_write_cache = orjson.dumps(response_parsed)
            r.set(cache_key, to_write_cache)
            r.expire(cache_key, settings.CACHE_TIMEOUT)

        time_took = (time.process_time() - start) * 1000
        return {"object_list": response_parsed, "took": _round_sig(time_took)}

    @abstractmethod
    def query_results(self, **kwargs):
        pass

    def process_results(self, response, **kwargs):
        """
        Post-process a result list according to the keyword flags:
        "annotate" calls annotate_results on it, "reverse" reverses it in
        place, and "dedup" runs dedup_list over "dedup_fields" (or a default
        field list when none is given). Returns the resulting list.
        """
        if kwargs.get("annotate"):
            annotate_results(response)
        if kwargs.get("reverse"):
            response.reverse()
        if kwargs.get("dedup"):
            dedup_fields = kwargs.get("dedup_fields")
            if not dedup_fields:
                dedup_fields = ["msg", "nick", "ident", "host", "net", "channel"]
            response = dedup_list(response, dedup_fields)
        return response

    @abstractmethod
    def parse(self, response):
        pass