From 7c69c99b8f718a7694de5977132c93fba2ec49f3 Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Mon, 13 Mar 2023 19:22:06 +0000 Subject: [PATCH] Remove DB and clean up references to it --- core/clients/aggregator.py | 7 +- core/clients/platform.py | 18 ++-- core/lib/db.py | 192 ------------------------------------- core/models.py | 27 ++++-- 4 files changed, 33 insertions(+), 211 deletions(-) diff --git a/core/clients/aggregator.py b/core/clients/aggregator.py index f374ee6..c075d96 100644 --- a/core/clients/aggregator.py +++ b/core/clients/aggregator.py @@ -1,5 +1,7 @@ from abc import ABC +import orjson + from core.clients.platforms.agora import AgoraClient from core.lib import notify from core.lib.money import money @@ -52,7 +54,7 @@ class AggregatorClient(ABC): if not transactions: return False - platforms = self.platforms + platforms = self.instance.platforms for transaction in transactions: transaction_id = transaction["transaction_id"] tx_obj = self.instance.get_transaction( @@ -73,6 +75,9 @@ class AggregatorClient(ABC): tx_cast, ) # New transaction + await notify.sendmsg( + f"New transaction: {orjson.dumps(tx_cast)}", title="New transaction" + ) await self.transaction(platforms, tx_obj) else: # Transaction exists diff --git a/core/clients/platform.py b/core/clients/platform.py index 731b0e2..23151e1 100644 --- a/core/clients/platform.py +++ b/core/clients/platform.py @@ -9,7 +9,7 @@ from aiocoingecko import AsyncCoinGeckoAPISession from django.conf import settings from core.clients.platforms.api.agoradesk import AgoraDesk -from core.lib import db, notify +from core.lib import notify from core.lib.antifraud import antifraud from core.lib.money import money from core.util import logs @@ -231,15 +231,14 @@ class LocalPlatformClient(ABC): if "data" not in messages["response"]: log.error(f"Data not in messages response: {messages['response']}") return False - ref_map = await db.get_ref_map() - open_tx = ref_map.keys() + open_tx = self.instance.trade_ids for message in messages["response"]["data"]["message_list"]: contact_id = str(message["contact_id"]) username = message["sender"]["username"] msg = message["msg"] if contact_id not in open_tx: continue - reference = await db.tx_to_ref(contact_id) + reference = self.instance.contact_id_to_reference(contact_id) if reference in messages_tmp: messages_tmp[reference].append([username, msg]) else: @@ -860,7 +859,6 @@ class LocalPlatformClient(ABC): """ reference = "".join(choices(ascii_uppercase, k=5)) reference = f"AGR-{reference}" - existing_ref = await db.r.get(f"trade.{trade_id}.reference") existing_ref = self.instance.contact_id_to_reference(trade_id) if not existing_ref: # to_store = { @@ -929,13 +927,13 @@ class LocalPlatformClient(ABC): :return: tuple of (platform, trade_id, reference, currency) """ platform, username = self.get_uid(uid) - refs = await db.get_refs() + refs = self.instance.references matching_trades = [] for reference in refs: - ref_data = await db.get_ref(reference) - tx_username = ref_data["buyer"] - trade_id = ref_data["id"] - currency = ref_data["currency"] + ref_data = self.instance.get_trade_by_reference(reference) + tx_username = ref_data.buyer + trade_id = ref_data.contact_id + currency = ref_data.currency if tx_username == username: to_append = (platform, trade_id, reference, currency) matching_trades.append(to_append) diff --git a/core/lib/db.py b/core/lib/db.py index 88fe502..e69de29 100644 --- a/core/lib/db.py +++ b/core/lib/db.py @@ -1,192 +0,0 @@ -from redis import asyncio as aioredis - -from core.util import logs - -log = logs.get_logger("db") - -r = aioredis.from_url("redis://redis:6379", db=0) # noqa - - -def convert(data): - """ - Recursively convert a dictionary. - """ - if isinstance(data, bytes): - return data.decode("ascii") - if isinstance(data, dict): - return dict(map(convert, data.items())) - if isinstance(data, tuple): - return map(convert, data) - if isinstance(data, list): - return list(map(convert, data)) - return data - - -async def get_refs(): - """ - Get all reference IDs for trades. - :return: list of trade IDs - :rtype: list - """ - r = aioredis.from_url("redis://redis:6379", db=0) - references = [] - ref_keys = await r.keys("trade.*.reference") - for key in ref_keys: - key_data = await r.get(key) - references.append(key_data) - return convert(references) - - -async def tx_to_ref(tx): - """ - Convert a trade ID to a reference. - :param tx: trade ID - :type tx: string - :return: reference - :rtype: string - """ - r = aioredis.from_url("redis://redis:6379", db=0) - refs = await get_refs() - for reference in refs: - ref_data = await r.hgetall(f"trade.{reference}") - ref_data = convert(ref_data) - if not ref_data: - continue - if ref_data["id"] == tx: - return reference - - -async def ref_to_tx(reference): - """ - Convert a reference to a trade ID. - :param reference: trade reference - :type reference: string - :return: trade ID - :rtype: string - """ - r = aioredis.from_url("redis://redis:6379", db=0) - ref_data = convert(await r.hgetall(f"trade.{reference}")) - if not ref_data: - return False - return ref_data["id"] - - -async def get_ref_map(): - """ - Get all reference IDs for trades. - :return: dict of references keyed by TXID - :rtype: dict - """ - r = aioredis.from_url("redis://redis:6379", db=0) - references = {} - ref_keys = await r.keys("trade.*.reference") - for key in ref_keys: - tx = convert(key).split(".")[1] - references[tx] = await r.get(key) - return convert(references) - - -async def get_ref(reference): - """ - Get the trade information for a reference. - :param reference: trade reference - :type reference: string - :return: dict of trade information - :rtype: dict - """ - r = aioredis.from_url("redis://redis:6379", db=0) - ref_data = await r.hgetall(f"trade.{reference}") - ref_data = convert(ref_data) - if "subclass" not in ref_data: - ref_data["subclass"] = "agora" - if not ref_data: - return False - return ref_data - - -async def get_tx(tx): - """ - Get the transaction information for a transaction ID. - :param reference: trade reference - :type reference: string - :return: dict of trade information - :rtype: dict - """ - r = aioredis.from_url("redis://redis:6379", db=0) - tx_data = await r.hgetall(f"tx.{tx}") - tx_data = convert(tx_data) - if not tx_data: - return False - return tx_data - - -async def get_subclass(reference): - r = aioredis.from_url("redis://redis:6379", db=0) - obj = await r.hget(f"trade.{reference}", "subclass") - subclass = convert(obj) - return subclass - - -async def del_ref(reference): - """ - Delete a given reference from the Redis database. - :param reference: trade reference to delete - :type reference: string - """ - r = aioredis.from_url("redis://redis:6379", db=0) - tx = await ref_to_tx(reference) - await r.delete(f"trade.{reference}") - await r.delete(f"trade.{tx}.reference") - - -async def cleanup(subclass, references): - """ - Reconcile the internal reference database with a given list of references. - Delete all internal references not present in the list and clean up artifacts. - :param references: list of references to reconcile against - :type references: list - """ - r = aioredis.from_url("redis://redis:6379", db=0) - messages = [] - ref_map = await get_ref_map() - for tx, reference in ref_map.items(): - if reference not in references: - if await get_subclass(reference) == subclass: - logmessage = ( - f"[{reference}] ({subclass}): Archiving trade reference. TX: {tx}" - ) - messages.append(logmessage) - log.info(logmessage) - await r.rename(f"trade.{tx}.reference", f"archive.trade.{tx}.reference") - await r.rename(f"trade.{reference}", f"archive.trade.{reference}") - return messages - - -async def find_trade(self, txid, currency, amount): - """ - Get a trade reference that matches the given currency and amount. - Only works if there is one result. - :param txid: Sink transaction ID - :param currency: currency - :param amount: amount - :type txid: string - :type currency: string - :type amount: int - :return: matching trade object or False - :rtype: dict or bool - """ - refs = await get_refs() - matching_refs = [] - # TODO: use get_ref_map in this function instead of calling get_ref multiple times - for ref in refs: - stored_trade = await get_ref(ref) - if stored_trade["currency"] == currency and float( - stored_trade["amount"] - ) == float(amount): - matching_refs.append(stored_trade) - if len(matching_refs) != 1: - log.error( - f"Find trade returned multiple results for TXID {txid}: {matching_refs}" - ) - return False - return matching_refs[0] diff --git a/core/models.py b/core/models.py index b697991..a163900 100644 --- a/core/models.py +++ b/core/models.py @@ -85,13 +85,11 @@ class Aggregator(models.Model): platforms=platform, enabled=True, ) - print("ADS", ads) for ad in ads: for aggregator in ad.aggregators.all(): if aggregator not in aggregators: aggregators.append(aggregator) - print("RET", aggregators) return aggregators @property @@ -106,13 +104,11 @@ class Aggregator(models.Model): aggregators=self, enabled=True, ) - print("ADS", ads) for ad in ads: for platform in ad.platforms.all(): if platform not in platforms: platforms.append(platform) - print("RET", platforms) return platforms @classmethod @@ -260,6 +256,18 @@ class Platform(models.Model): return references + @property + def trade_ids(self): + """ + Get trade IDs of all our trades that are open. + """ + references = [] + our_trades = Trade.objects.filter(platform=self, open=True) + for trade in our_trades: + references.append(trade.contact_id) + + return references + def get_trade_by_reference(self, reference): return Trade.objects.filter( platform=self, @@ -286,6 +294,13 @@ class Platform(models.Model): return None return trade.reference + def get_trade_by_trade_id(self, trade_id): + return Trade.objects.filter( + platform=self, + open=True, + contact_id=trade_id, + ).first() + def new_trade(self, trade_cast): trade = Trade.objects.create( platform=self, @@ -315,13 +330,11 @@ class Platform(models.Model): aggregators=aggregator, enabled=True, ) - print("ADS", ads) for ad in ads: for platform in ad.platforms.all(): if platform not in platforms: platforms.append(platform) - print("RET", platforms) return platforms @property @@ -336,13 +349,11 @@ class Platform(models.Model): platforms=self, enabled=True, ) - print("ADS", ads) for ad in ads: for aggregator in ad.aggregators.all(): if aggregator not in aggregators: aggregators.append(aggregator) - print("RET", aggregators) return aggregators