neptune/core/db/__init__.py

import random
import string
import time
from abc import ABC, abstractmethod
from math import floor, log10

import orjson
from django.conf import settings
from siphashc import siphash

from core import r
from core.db.processing import annotate_results
from core.util import logs


def remove_defaults(query_params):
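    """
    Remove any query parameters that are identical to the configured
    DRILLDOWN_DEFAULT_PARAMS, mutating query_params in place.
    """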
    for field, value in list(query_params.items()):
        if field in settings.DRILLDOWN_DEFAULT_PARAMS:
            if value == settings.DRILLDOWN_DEFAULT_PARAMS[field]:
                del query_params[field]


def add_defaults(query_params):
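    """
    Fill in any missing query parameters from DRILLDOWN_DEFAULT_PARAMS,
    mutating query_params in place.
    """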
    for field, value in settings.DRILLDOWN_DEFAULT_PARAMS.items():
        if field not in query_params:
            query_params[field] = value


def dedup_list(data, check_keys):
"""
Remove duplicate dictionaries from list.
"""
seen = set()
out = []
dup_count = 0
for x in data:
dedupeKey = tuple(x[k] for k in check_keys if k in x)
if dedupeKey in seen:
dup_count += 1
continue
if dup_count > 0:
out.append({"type": "control", "hidden": dup_count})
dup_count = 0
out.append(x)
seen.add(dedupeKey)
if dup_count > 0:
out.append({"type": "control", "hidden": dup_count})
return out
class StorageBackend(ABC):
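    """
    Abstract base class for search/storage backends (e.g. Elasticsearch
    or Druid). Provides query parsing, result caching, blacklist
    filtering and result post-processing; concrete backends implement
    initialise, construct_query, run_query, query_results and parse.
    """
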
    def __init__(self, name):
        self.log = logs.get_logger(name)
        self.log.info(f"Initialising storage backend {name}")
        self.initialise_caching()
        # self.initialise()

    @abstractmethod
    def initialise(self, **kwargs):
        pass

    def initialise_caching(self):
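        """
        Load the cache hash key from the shared cache store, creating and
        persisting a random key on first use. The key is used to hash
        queries for the per-user query cache.
        """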
        hash_key = r.get("cache_hash_key")
        if not hash_key:
            letters = string.ascii_lowercase
            hash_key = "".join(random.choice(letters) for _ in range(16))
            self.log.debug(f"Created new hash key: {hash_key}")
            r.set("cache_hash_key", hash_key)
        else:
            hash_key = hash_key.decode("ascii")
            self.log.debug(f"Decoded hash key: {hash_key}")
        self.hash_key = hash_key

    @abstractmethod
    def construct_query(self, **kwargs):
        pass

    def parse_query(self, query_params, tags, size, custom_query, add_bool, **kwargs):
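        """
        Build a search query from the request parameters.
        Uses the raw text query if given, otherwise a supplied custom_query
        or a blank query. Any tags are appended to add_bool (mutated in
        place) as boolean match clauses. Returns the constructed query, or
        a message dict if validation fails.
        """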
        query_created = False
        if "query" in query_params:
            query = query_params["query"]
            search_query = self.construct_query(query, size, **kwargs)
            query_created = True
        else:
            if custom_query:
                search_query = custom_query
            else:
                search_query = self.construct_query(None, size, blank=True, **kwargs)
        if tags:
            # Get a blank search query
            if not query_created:
                search_query = self.construct_query(None, size, blank=True, **kwargs)
                query_created = True
            for item in tags:
                for tagname, tagvalue in item.items():
                    add_bool.append({tagname: tagvalue})
        bypass_check = kwargs.get("bypass_check", False)
        if not bypass_check:
            valid = self.check_valid_query(query_params, custom_query, **kwargs)
            if isinstance(valid, dict):
                return valid
        return search_query

    def check_valid_query(self, query_params, custom_query, **kwargs):
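        """
        Ensure the query contains at least a text query or tags, unless a
        custom_query is supplied. Returns a message dict describing the
        problem, or None if the query is valid.
        """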
        required_any = ["query", "tags"]
        if not any([field in query_params.keys() for field in required_any]):
            if not custom_query:
                message = "Empty query!"
                message_class = "warning"
                return {"message": message, "class": message_class}

    @abstractmethod
    def run_query(self, **kwargs):
        pass

    def filter_blacklisted(self, user, response):
        """
        Low level filter to take the raw search response and remove
        objects from it we want to keep secret.
        Does not return, the object is mutated in place.
        """
        response["redacted"] = 0
        response["exemption"] = None
        if user.is_superuser:
            response["exemption"] = True
        # is_anonymous = isinstance(user, AnonymousUser)
        # For every hit from ES
        for index, item in enumerate(list(response["hits"]["hits"])):
            # For every blacklisted type
            for blacklisted_type in settings.ELASTICSEARCH_BLACKLISTED.keys():
                # Check this field we are matching exists
                if "_source" in item.keys():
                    data_index = "_source"
                elif "fields" in item.keys():
                    data_index = "fields"
                else:
                    return False
                if blacklisted_type in item[data_index].keys():
                    content = item[data_index][blacklisted_type]
                    # For every item in the blacklisted array for the type
                    for blacklisted_item in settings.ELASTICSEARCH_BLACKLISTED[
                        blacklisted_type
                    ]:
                        if blacklisted_item == str(content):
                            # Remove the item
                            if item in response["hits"]["hits"]:
                                # Let the UI know something was redacted
                                if (
                                    "exemption"
                                    not in response["hits"]["hits"][index][data_index]
                                ):
                                    response["redacted"] += 1
                                # Anonymous
                                if user.is_anonymous:
                                    # Just set it to none so the index is not off
                                    response["hits"]["hits"][index] = None
                                else:
                                    if not user.has_perm("core.bypass_blacklist"):
                                        response["hits"]["hits"][index] = None
                                    else:
                                        response["hits"]["hits"][index][data_index][
                                            "exemption"
                                        ] = True
        # Actually get rid of all the things we set to None
        response["hits"]["hits"] = [hit for hit in response["hits"]["hits"] if hit]

    def add_bool(self, search_query, add_bool):
        """
        Add the specified boolean matches to search query.
        """
        if not add_bool:
            return
        for item in add_bool:
            search_query["query"]["bool"]["must"].append({"match_phrase": item})

    def add_top(self, search_query, add_top, negative=False):
        """
        Merge add_top with the base of the search_query.
        """
        if not add_top:
            return
        if negative:
            for item in add_top:
                if "must_not" in search_query["query"]["bool"]:
                    search_query["query"]["bool"]["must_not"].append(item)
                else:
                    search_query["query"]["bool"]["must_not"] = [item]
        else:
            for item in add_top:
                if "query" not in search_query:
                    search_query["query"] = {"bool": {"must": []}}
                search_query["query"]["bool"]["must"].append(item)

    def schedule_check_aggregations(self, rule_object, result_map):
        """
        Check the aggregation values in a scheduled query's results against
        the operators and thresholds configured on the rule, marking each
        aggregation as matched or not.
        """
        if rule_object.aggs is None:
            return result_map
        for index, (meta, result) in result_map.items():
            # Default to true, if no aggs are found, we still want to match
            match = True
            for agg_name, (operator, number) in rule_object.aggs.items():
                if agg_name in meta["aggs"]:
                    agg_value = meta["aggs"][agg_name]["value"]
                    # TODO: simplify this, match is default to True
                    if operator == ">":
                        if agg_value > number:
                            match = True
                        else:
                            match = False
                    elif operator == "<":
                        if agg_value < number:
                            match = True
                        else:
                            match = False
                    elif operator == "=":
                        if agg_value == number:
                            match = True
                        else:
                            match = False
                    else:
                        match = False
                else:
                    # No aggregation found, but it is required
                    match = False
                result_map[index][0]["aggs"][agg_name]["match"] = match
        return result_map

    def query(self, user, search_query, **kwargs):
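        """
        Run a search query against the backend, using and populating the
        per-user query cache when settings.CACHE is enabled. Returns a dict
        with the parsed results and the time taken, or a message dict on
        error or when there are no results.
        """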
        # For time tracking
        start = time.process_time()
        if settings.CACHE:
            # Sort the keys so the hash is the same
            query_normalised = orjson.dumps(search_query, option=orjson.OPT_SORT_KEYS)
            query_hash = siphash(self.hash_key, query_normalised)
            cache_hit = r.get(f"query_cache.{user.id}.{query_hash}")
            if cache_hit:
                response = orjson.loads(cache_hit)
                time_took = (time.process_time() - start) * 1000
                # Round to 3 significant figures
                time_took_rounded = round(
                    time_took, 3 - int(floor(log10(abs(time_took)))) - 1
                )
                return {
                    "object_list": response,
                    "took": time_took_rounded,
                    "cache": True,
                }
print("S2", search_query)
        response = self.run_query(user, search_query, **kwargs)
        # For Elasticsearch
        if isinstance(response, Exception):
            message = f"Error: {response.info['error']['root_cause'][0]['type']}"
            message_class = "danger"
            return {"message": message, "class": message_class}
        if "took" in response:
            if response["took"] is None:
                return None
            if "error" in response:
                message = f"Error: {response['error']}"
                message_class = "danger"
                time_took = (time.process_time() - start) * 1000
                # Round to 3 significant figures
                time_took_rounded = round(
                    time_took, 3 - int(floor(log10(abs(time_took)))) - 1
                )
                return {
                    "message": message,
                    "class": message_class,
                    "took": time_took_rounded,
                }
            elif len(response["hits"]["hits"]) == 0:
                message = "No results."
                message_class = "danger"
                time_took = (time.process_time() - start) * 1000
                # Round to 3 significant figures
                time_took_rounded = round(
                    time_took, 3 - int(floor(log10(abs(time_took)))) - 1
                )
                return {
                    "message": message,
                    "class": message_class,
                    "took": time_took_rounded,
                }
        # For Druid
        elif "error" in response:
            if "errorMessage" in response:
                context = {
                    "message": response["errorMessage"],
                    "class": "danger",
                }
                return context
            else:
                return response
        # Removed for now, no point given we have restricted indexes
        # self.filter_blacklisted(user, response)
        # Parse the response
        response_parsed = self.parse(response)
        # Write cache
        if settings.CACHE:
            to_write_cache = orjson.dumps(response_parsed)
            r.set(f"query_cache.{user.id}.{query_hash}", to_write_cache)
            r.expire(f"query_cache.{user.id}.{query_hash}", settings.CACHE_TIMEOUT)
        time_took = (time.process_time() - start) * 1000
        # Round to 3 significant figures
        time_took_rounded = round(time_took, 3 - int(floor(log10(abs(time_took)))) - 1)
        return {"object_list": response_parsed, "took": time_took_rounded}

    def construct_context_query(
        self, index, net, channel, src, num, size, type=None, nicks=None
    ):
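        """
        Build a query for fetching context lines: messages on the given
        network and channel (or query/ZNC window), restricted to the
        relevant message types and, optionally, to a set of nicks.
        """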
        # Get the initial query
        query = self.construct_query(None, size, blank=True)
        extra_must = []
        extra_should = []
        extra_should2 = []
        if num:
            extra_must.append({"match_phrase": {"num": num}})
        if net:
            extra_must.append({"match_phrase": {"net": net}})
        if channel:
            extra_must.append({"match": {"channel": channel}})
        if nicks:
            for nick in nicks:
                extra_should2.append({"match": {"nick": nick}})
        types = ["msg", "notice", "action", "kick", "topic", "mode"]
        fields = [
            "nick",
            "ident",
            "host",
            "channel",
            "ts",
            "msg",
            "type",
            "net",
            "src",
            "tokens",
        ]
        query["fields"] = fields
        if index == "internal":
            fields.append("mtype")
            if channel == "*status" or type == "znc":
                if {"match": {"channel": channel}} in extra_must:
                    extra_must.remove({"match": {"channel": channel}})
                extra_should2 = []
                # Type is one of msg or notice
                # extra_should.append({"match": {"mtype": "msg"}})
                # extra_should.append({"match": {"mtype": "notice"}})
                extra_should.append({"match": {"type": "znc"}})
                extra_should.append({"match": {"type": "self"}})
                extra_should2.append({"match": {"type": "znc"}})
                extra_should2.append({"match": {"nick": channel}})
            elif type == "auth":
                if {"match": {"channel": channel}} in extra_must:
                    extra_must.remove({"match": {"channel": channel}})
                extra_should2 = []
                extra_should2.append({"match": {"nick": channel}})
                # extra_should2.append({"match": {"mtype": "msg"}})
                # extra_should2.append({"match": {"mtype": "notice"}})
                extra_should.append({"match": {"type": "query"}})
                extra_should2.append({"match": {"type": "self"}})
                extra_should.append({"match": {"nick": channel}})
            else:
                for ctype in types:
                    extra_should.append({"match": {"mtype": ctype}})
        else:
            for ctype in types:
                extra_should.append({"match": {"type": ctype}})
        # query = {
        #     "index": index,
        #     "limit": size,
        #     "query": {
        #         "bool": {
        #             "must": [
        #                 # {"equals": {"src": src}},
        #                 # {
        #                 #     "bool": {
        #                 #         "should": [*extra_should],
        #                 #     }
        #                 # },
        #                 # {
        #                 #     "bool": {
        #                 #         "should": [*extra_should2],
        #                 #     }
        #                 # },
        #                 *extra_must,
        #             ]
        #         }
        #     },
        #     "fields": fields,
        #     # "_source": False,
        # }
        if extra_must:
            for x in extra_must:
                query["query"]["bool"]["must"].append(x)
        if extra_should:
            query["query"]["bool"]["must"].append(
                {"bool": {"should": [*extra_should]}}
            )
        if extra_should2:
            query["query"]["bool"]["must"].append(
                {"bool": {"should": [*extra_should2]}}
            )
        return query

    @abstractmethod
    def query_results(self, **kwargs):
        pass

    def process_results(self, response, **kwargs):
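        """
        Post-process a list of results: optionally annotate them, reverse
        their order, and deduplicate them on dedup_fields (defaulting to
        the message and sender fields).
        """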
        if kwargs.get("annotate"):
            annotate_results(response)
        if kwargs.get("reverse"):
            response.reverse()
        if kwargs.get("dedup"):
            dedup_fields = kwargs.get("dedup_fields")
            if not dedup_fields:
                dedup_fields = ["msg", "nick", "ident", "host", "net", "channel"]
            response = dedup_list(response, dedup_fields)
        return response

    @abstractmethod
    def parse(self, response):
        pass