import random
import string
import time
from abc import ABC, abstractmethod
from math import floor, log10

import orjson
from django.conf import settings
from siphashc import siphash

from core import r
from core.db.processing import annotate_results
from core.util import logs


def remove_defaults(query_params):
    """
    Delete every entry of query_params whose value equals the default
    configured for that field in settings.DRILLDOWN_DEFAULT_PARAMS.
    The mapping is mutated in place.
    """
    defaults = settings.DRILLDOWN_DEFAULT_PARAMS
    # Iterate over a snapshot because entries are deleted during the loop
    for field, value in list(query_params.items()):
        if field in defaults and value == defaults[field]:
            del query_params[field]


def add_defaults(query_params):
    """
    Fill in every field of settings.DRILLDOWN_DEFAULT_PARAMS that is missing
    from query_params. Existing values are left alone. Mutates in place.
    """
    for field, value in settings.DRILLDOWN_DEFAULT_PARAMS.items():
        if field not in query_params:
            query_params[field] = value


def dedup_list(data, check_keys):
    """
    Remove duplicate dictionaries from list.

    Two items are duplicates when the tuple of their values for check_keys
    (keys absent from an item are skipped) is equal. The first occurrence is
    kept. Each run of dropped items is replaced by one marker of the form
    {"type": "control", "hidden": <count>}, placed before the next kept item
    (or at the end of the list).
    """
    seen = set()
    out = []
    dup_count = 0
    for x in data:
        dedupe_key = tuple(x[k] for k in check_keys if k in x)
        if dedupe_key in seen:
            dup_count += 1
            continue
        if dup_count > 0:
            out.append({"type": "control", "hidden": dup_count})
            dup_count = 0
        out.append(x)
        seen.add(dedupe_key)
    if dup_count > 0:
        out.append({"type": "control", "hidden": dup_count})
    return out


def _round_sig(value, digits=3):
    """
    Round value to the given number of significant figures.

    Zero is returned as 0.0 instead of being passed to log10, which raises
    ValueError for 0.
    """
    if value == 0:
        return 0.0
    return round(value, digits - int(floor(log10(abs(value)))) - 1)


class StorageBackend(ABC):
    """
    Abstract base for search storage backends.

    Subclasses implement initialise, construct_query, run_query,
    query_results and parse. This base class provides query validation,
    result caching through `r`, and result post-processing.
    """

    def __init__(self, name):
        self.log = logs.get_logger(name)
        self.log.info(f"Initialising storage backend {name}")
        self.initialise_caching()
        # self.initialise()

    @abstractmethod
    def initialise(self, **kwargs):
        pass

    def initialise_caching(self):
        """
        Load the shared cache hash key from `r` into self.hash_key, creating
        and storing a random 16 lowercase-letter key if none exists yet.
        """
        hash_key = r.get("cache_hash_key")
        if not hash_key:
            letters = string.ascii_lowercase
            hash_key = "".join(random.choice(letters) for _ in range(16))
            self.log.debug(f"Created new hash key: {hash_key}")
            r.set("cache_hash_key", hash_key)
        else:
            # The stored value is decoded before use; a fresh key is already str
            hash_key = hash_key.decode("ascii")
            self.log.debug(f"Decoded hash key: {hash_key}")
        self.hash_key = hash_key

    @abstractmethod
    def construct_query(self, **kwargs):
        pass

    def parse_query(self, query_params, tags, size, custom_query, add_bool, **kwargs):
        """
        Build the search query for a request.

        Uses query_params["query"] when present, otherwise custom_query,
        otherwise a blank query. When tags are given, each tag name/value
        pair is appended to add_bool (mutated in place) and, if no query was
        built from query_params, a blank query replaces the current one.

        Unless kwargs["bypass_check"] is truthy the request is validated with
        check_valid_query; if that yields a message dict, the dict is
        returned instead of the search query.
        """
        query_created = False
        if "query" in query_params:
            query = query_params["query"]
            search_query = self.construct_query(query, size, **kwargs)
            query_created = True
        else:
            if custom_query:
                search_query = custom_query
            else:
                search_query = self.construct_query(None, size, blank=True, **kwargs)

        if tags:
            # Get a blank search query
            if not query_created:
                search_query = self.construct_query(None, size, blank=True, **kwargs)
                query_created = True
            for item in tags:
                for tagname, tagvalue in item.items():
                    add_bool.append({tagname: tagvalue})

        bypass_check = kwargs.get("bypass_check", False)
        if not bypass_check:
            valid = self.check_valid_query(query_params, custom_query, **kwargs)
            if isinstance(valid, dict):
                return valid

        return search_query

    def check_valid_query(self, query_params, custom_query, **kwargs):
        """
        Return a warning message dict when the request has neither a "query"
        nor a "tags" parameter and no custom_query; otherwise return None.

        Extra keyword arguments are accepted and ignored: parse_query
        forwards its own **kwargs here, which previously raised TypeError
        whenever any were supplied.
        """
        required_any = ["query", "tags"]
        if not any(field in query_params for field in required_any):
            if not custom_query:
                message = "Empty query!"
                message_class = "warning"
                return {"message": message, "class": message_class}
        return None

    @abstractmethod
    def run_query(self, **kwargs):
        pass

    def filter_blacklisted(self, user, response):
        """
        Low level filter to take the raw search response and remove objects
        from it we want to keep secret.

        The response is mutated in place: "redacted" and "exemption" keys are
        set on it and matching hits are dropped or flagged. Nothing is
        returned on the normal path, but the method returns False early
        (leaving the hit list unfiltered) when a hit has neither a "_source"
        nor a "fields" key.
        """
        response["redacted"] = 0
        response["exemption"] = None
        if user.is_superuser:
            response["exemption"] = True
        # is_anonymous = isinstance(user, AnonymousUser)
        # For every hit from ES
        for index, item in enumerate(list(response["hits"]["hits"])):
            # For every blacklisted type
            for blacklisted_type in settings.ELASTICSEARCH_BLACKLISTED.keys():
                # Check this field we are matching exists
                if "_source" in item.keys():
                    data_index = "_source"
                elif "fields" in item.keys():
                    data_index = "fields"
                else:
                    return False
                if blacklisted_type in item[data_index].keys():
                    content = item[data_index][blacklisted_type]
                    # For every item in the blacklisted array for the type
                    # NOTE(review): the types come from ELASTICSEARCH_BLACKLISTED
                    # but the values from BLACKLISTED - confirm both settings
                    # exist and that this is intended.
                    for blacklisted_item in settings.BLACKLISTED[blacklisted_type]:
                        if blacklisted_item == str(content):
                            # Remove the item
                            if item in response["hits"]["hits"]:
                                # Let the UI know something was redacted
                                if (
                                    "exemption"
                                    not in response["hits"]["hits"][index][data_index]
                                ):
                                    response["redacted"] += 1
                                # Anonymous
                                if user.is_anonymous:
                                    # Just set it to none so the index is not off
                                    response["hits"]["hits"][index] = None
                                else:
                                    if not user.has_perm("core.bypass_blacklist"):
                                        response["hits"]["hits"][index] = None
                                    else:
                                        response["hits"]["hits"][index][data_index][
                                            "exemption"
                                        ] = True

        # Actually get rid of all the things we set to None
        response["hits"]["hits"] = [hit for hit in response["hits"]["hits"] if hit]

    def query(self, user, search_query, **kwargs):
        """
        Run search_query for user, going through the per-user cache when
        settings.CACHE is enabled.

        Returns one of:
        - {"object_list", "took", "cache": True} on a cache hit;
        - {"object_list", "took"} after a successful backend query;
        - {"message", "class"} for a backend error or an empty result;
        - the raw response when it contains "error" but no "errorMessage";
        - None when the response's "took" value is None.

        "took" is in milliseconds, rounded to 3 significant figures.
        """
        # For time tracking
        # NOTE(review): process_time() counts CPU time only, so time spent
        # waiting on the backend or the cache is not included - confirm that
        # this is the intended measurement.
        start = time.process_time()

        cache_key = None
        if settings.CACHE:
            # Sort the keys so the hash is the same
            query_normalised = orjson.dumps(search_query, option=orjson.OPT_SORT_KEYS)
            query_hash = siphash(self.hash_key, query_normalised)
            cache_key = f"query_cache.{user.id}.{query_hash}"
            cache_hit = r.get(cache_key)
            if cache_hit:
                response = orjson.loads(cache_hit)
                time_took = (time.process_time() - start) * 1000
                return {
                    "object_list": response,
                    "took": _round_sig(time_took),
                    "cache": True,
                }

        response = self.run_query(user, search_query, **kwargs)

        # For Elasticsearch
        if isinstance(response, Exception):
            message = f"Error: {response.info['error']['root_cause'][0]['type']}"
            message_class = "danger"
            return {"message": message, "class": message_class}
        if len(response["hits"]["hits"]) == 0:
            message = "No results."
            message_class = "danger"
            return {"message": message, "class": message_class}

        # For Druid
        if "error" in response:
            if "errorMessage" in response:
                context = {
                    "message": response["errorMessage"],
                    "class": "danger",
                }
                return context
            else:
                return response

        if "took" in response:
            if response["took"] is None:
                return None

        # Removed for now, no point given we have restricted indexes
        # self.filter_blacklisted(user, response)

        # Parse the response
        response_parsed = self.parse(response)

        # Write cache
        if cache_key is not None:
            to_write_cache = orjson.dumps(response_parsed)
            r.set(cache_key, to_write_cache)
            r.expire(cache_key, settings.CACHE_TIMEOUT)

        time_took = (time.process_time() - start) * 1000
        return {"object_list": response_parsed, "took": _round_sig(time_took)}

    @abstractmethod
    def query_results(self, **kwargs):
        pass

    def process_results(self, response, **kwargs):
        """
        Post-process a result list according to keyword flags.

        annotate: pass the list to annotate_results.
        reverse: reverse the list in place.
        dedup: collapse duplicates with dedup_list, comparing on
            kwargs["dedup_fields"] or a default field list when that is
            missing or empty.

        Returns the (possibly new) result list.
        """
        if kwargs.get("annotate"):
            annotate_results(response)
        if kwargs.get("reverse"):
            response.reverse()
        if kwargs.get("dedup"):
            dedup_fields = kwargs.get("dedup_fields")
            if not dedup_fields:
                dedup_fields = ["msg", "nick", "ident", "host", "net", "channel"]
            response = dedup_list(response, dedup_fields)
        return response

    @abstractmethod
    def parse(self, response):
        pass