diff --git a/handler/db.py b/handler/db.py index 088c684..83d185c 100644 --- a/handler/db.py +++ b/handler/db.py @@ -3,6 +3,131 @@ from redis import StrictRedis # Project imports from settings import settings +import util + +log = util.get_logger("db") # Define the Redis endpoint to the socket r = StrictRedis(unix_socket_path=settings.DB.RedisSocket, db=int(settings.DB.DB)) + + +def get_refs(): + """ + Get all reference IDs for trades. + :return: list of trade IDs + :rtype: list + """ + references = [] + ref_keys = r.keys("trade.*.reference") + for key in ref_keys: + references.append(r.get(key)) + return util.convert(references) + + +def tx_to_ref(self, tx): + """ + Convert a trade ID to a reference. + :param tx: trade ID + :type tx: string + :return: reference + :rtype: string + """ + refs = get_refs() + for reference in refs: + ref_data = util.convert(r.hgetall(f"trade.{reference}")) + if not ref_data: + continue + if ref_data["id"] == tx: + return reference + + +def ref_to_tx(self, reference): + """ + Convert a reference to a trade ID. + :param reference: trade reference + :type reference: string + :return: trade ID + :rtype: string + """ + ref_data = util.convert(r.hgetall(f"trade.{reference}")) + if not ref_data: + return False + return ref_data["id"] + + +def get_ref_map(): + """ + Get all reference IDs for trades. + :return: dict of references keyed by TXID + :rtype: dict + """ + references = {} + ref_keys = r.keys("trade.*.reference") + for key in ref_keys: + tx = util.convert(key).split(".")[1] + references[tx] = r.get(key) + return util.convert(references) + + +def get_ref(reference): + """ + Get the trade information for a reference. + :param reference: trade reference + :type reference: string + :return: dict of trade information + :rtype: dict + """ + ref_data = r.hgetall(f"trade.{reference}") + ref_data = util.convert(ref_data) + if "subclass" not in ref_data: + ref_data["subclass"] = "agora" + if not ref_data: + return False + return ref_data + + +def get_tx(tx): + """ + Get the transaction information for a transaction ID. + :param reference: trade reference + :type reference: string + :return: dict of trade information + :rtype: dict + """ + tx_data = r.hgetall(f"tx.{tx}") + tx_data = util.convert(tx_data) + if not tx_data: + return False + return tx_data + + +def get_subclass(reference): + obj = r.hget(f"trade.{reference}", "subclass") + subclass = util.convert(obj) + return subclass + + +def del_ref(reference): + """ + Delete a given reference from the Redis database. + :param reference: trade reference to delete + :type reference: string + """ + tx = ref_to_tx(reference) + r.delete(f"trade.{reference}") + r.delete(f"trade.{tx}.reference") + + +def cleanup(subclass, references): + """ + Reconcile the internal reference database with a given list of references. + Delete all internal references not present in the list and clean up artifacts. + :param references: list of references to reconcile against + :type references: list + """ + for tx, reference in get_ref_map().items(): + if reference not in references: + if get_subclass(reference) == subclass: + log.info(f"Archiving trade reference: {reference} / TX: {tx}") + r.rename(f"trade.{tx}.reference", f"archive.trade.{tx}.reference") + r.rename(f"trade.{reference}", f"archive.trade.{reference}") diff --git a/handler/sources/local.py b/handler/sources/local.py index 3b86112..4380bc2 100644 --- a/handler/sources/local.py +++ b/handler/sources/local.py @@ -13,6 +13,7 @@ from settings import settings import util from lib.agoradesk_py import AgoraDesk from lib.localbitcoins_py import LocalBitcoins +import db class Local(util.Base): @@ -89,7 +90,7 @@ class Local(util.Base): if dash is False: return False for contact_id, contact in dash.items(): - reference = self.tx.tx_to_ref(contact_id) + reference = db.tx_to_ref(contact_id) buyer = contact["data"]["buyer"]["username"] amount = contact["data"]["amount"] if self.platform == "agora": @@ -123,7 +124,7 @@ class Local(util.Base): if not dash.items(): return for contact_id, contact in dash.items(): - reference = self.tx.tx_to_ref(str(contact_id)) + reference = db.tx_to_ref(str(contact_id)) if reference: current_trades.append(reference) buyer = contact["data"]["buyer"]["username"] @@ -170,7 +171,7 @@ class Local(util.Base): self.last_dash.remove(ref) if reference and reference not in current_trades: current_trades.append(reference) - self.tx.cleanup(self.platform, current_trades) + db.cleanup(self.platform, current_trades) def got_recent_messages(self, messages, send_irc=True): """ @@ -184,14 +185,14 @@ class Local(util.Base): if "data" not in messages["response"]: self.log.error(f"Data not in messages response: {messages['response']}") return False - open_tx = self.tx.get_ref_map().keys() + open_tx = db.get_ref_map().keys() for message in messages["response"]["data"]["message_list"]: contact_id = message["contact_id"] username = message["sender"]["username"] msg = message["msg"] if contact_id not in open_tx: continue - reference = self.tx.tx_to_ref(contact_id) + reference = db.tx_to_ref(contact_id) if reference in messages_tmp: messages_tmp[reference].append([username, msg]) else: diff --git a/handler/tests/test_agora.py b/handler/tests/test_agora.py index 1596e00..9349cb7 100644 --- a/handler/tests/test_agora.py +++ b/handler/tests/test_agora.py @@ -1,5 +1,7 @@ from unittest import TestCase from unittest.mock import MagicMock, patch + +# from twisted.internet.defer import inlineCallbacks from json import loads from copy import deepcopy diff --git a/handler/tests/test_lbtc.py b/handler/tests/test_lbtc.py index 9b0bfaa..a50b711 100644 --- a/handler/tests/test_lbtc.py +++ b/handler/tests/test_lbtc.py @@ -1,5 +1,7 @@ from unittest import TestCase from unittest.mock import MagicMock, patch + +# from twisted.internet.defer import inlineCallbacks from json import loads from copy import deepcopy diff --git a/handler/tests/test_transactions.py b/handler/tests/test_transactions.py index b78d7e8..2af04e4 100644 --- a/handler/tests/test_transactions.py +++ b/handler/tests/test_transactions.py @@ -30,10 +30,10 @@ class TestTransactions(TestCase): } # Mock redis calls - transactions.r.hgetall = self.mock_hgetall - transactions.r.hmset = self.mock_hmset - transactions.r.keys = self.mock_keys - transactions.r.get = self.mock_get + transactions.db.r.hgetall = self.mock_hgetall + transactions.db.r.hmset = self.mock_hmset + transactions.db.r.keys = self.mock_keys + transactions.db.r.get = self.mock_get # Mock some callbacks self.transactions.irc = MagicMock() diff --git a/handler/transactions.py b/handler/transactions.py index e935f98..e3150e3 100644 --- a/handler/transactions.py +++ b/handler/transactions.py @@ -13,7 +13,7 @@ import logging # Project imports from settings import settings -from db import r +import db import util # TODO: secure ES traffic properly @@ -131,7 +131,7 @@ class Transactions(util.Base): # Split the reference into parts ref_split = reference.split(" ") # Get all existing references - existing_refs = self.get_refs() + existing_refs = db.get_refs() # Get all parts of the given reference split that match the existing references stored_trade_reference = set(existing_refs).intersection(set(ref_split)) if len(stored_trade_reference) > 1: @@ -174,7 +174,7 @@ class Transactions(util.Base): return stored_trade def normal_lookup(self, stored_trade_reference, reference, currency, amount): - stored_trade = self.get_ref(stored_trade_reference) + stored_trade = db.get_ref(stored_trade_reference) if not stored_trade: self.log.info(f"No reference in DB for {reference}") self.irc.sendmsg(f"No reference in DB for {reference}") @@ -227,7 +227,7 @@ class Transactions(util.Base): :param bank_sender: the sender name from the bank """ key = f"namemap.{platform}.{platform_buyer}" - r.sadd(key, bank_sender) + db.r.sadd(key, bank_sender) def get_previous_senders(self, platform, platform_buyer): """ @@ -238,7 +238,7 @@ class Transactions(util.Base): :rtype: set """ key = f"namemap.{platform}.{platform_buyer}" - senders = r.smembers(key) + senders = db.r.smembers(key) if not senders: return None senders = util.convert(senders) @@ -273,10 +273,10 @@ class Transactions(util.Base): :param tx: the transaction ID :param reference: the trade reference """ - stored_trade = self.get_ref(reference) + stored_trade = db.get_ref(reference) if not stored_trade: return None - stored_tx = self.get_tx(tx) + stored_tx = db.get_tx(tx) if not stored_tx: return None bank_sender = stored_tx["sender"] @@ -292,11 +292,11 @@ class Transactions(util.Base): Update a trade to point to a given transaction ID. Return False if the trade already has a mapped transaction. """ - existing_tx = r.hget(f"trade.{reference}", "tx") + existing_tx = db.r.hget(f"trade.{reference}", "tx") if existing_tx is None: return None elif existing_tx == b"": - r.hset(f"trade.{reference}", "tx", txid) + db.r.hset(f"trade.{reference}", "tx", txid) return True else: # Already a mapped transaction return False @@ -329,7 +329,7 @@ class Transactions(util.Base): "currency": currency, "sender": sender, } - r.hmset(f"tx.{txid}", to_store) + db.r.hmset(f"tx.{txid}", to_store) self.log.info(f"Transaction processed: {dumps(to_store, indent=2)}") self.irc.sendmsg(f"AUTO Incoming transaction on {subclass}: {txid} {amount}{currency} ({reference})") @@ -390,7 +390,7 @@ class Transactions(util.Base): self.release_funds(stored_trade["id"], stored_trade["reference"]) def release_funds(self, trade_id, reference): - stored_trade = self.get_ref(reference) + stored_trade = db.get_ref(reference) platform = stored_trade["subclass"] logmessage = f"All checks passed, releasing funds for {trade_id} {reference}" self.log.info(logmessage) @@ -422,11 +422,11 @@ class Transactions(util.Base): Map a trade to a transaction and release if no other TX is mapped to the same trade. """ - stored_trade = self.get_ref(reference) + stored_trade = db.get_ref(reference) if not stored_trade: self.log.error(f"Could not get stored trade for {reference}.") return None - tx_obj = self.get_tx(tx) + tx_obj = db.get_tx(tx) if not tx_obj: self.log.error(f"Could not get TX for {tx}.") return None @@ -453,10 +453,10 @@ class Transactions(util.Base): :return: tuple of (platform, trade_id, reference, currency) """ platform, username = self.get_uid(uid) - refs = self.get_refs() + refs = db.get_refs() matching_trades = [] for reference in refs: - ref_data = self.get_ref(reference) + ref_data = db.get_ref(reference) tx_platform = ref_data["subclass"] tx_username = ref_data["buyer"] trade_id = ref_data["id"] @@ -553,7 +553,7 @@ class Transactions(util.Base): """ reference = "".join(choices(ascii_uppercase, k=5)) reference = f"PGN-{reference}" - existing_ref = r.get(f"trade.{trade_id}.reference") + existing_ref = db.r.get(f"trade.{trade_id}.reference") if not existing_ref: to_store = { "id": trade_id, @@ -568,8 +568,8 @@ class Transactions(util.Base): "subclass": subclass, } self.log.info(f"Storing trade information: {str(to_store)}") - r.hmset(f"trade.{reference}", to_store) - r.set(f"trade.{trade_id}.reference", reference) + db.r.hmset(f"trade.{reference}", to_store) + db.r.set(f"trade.{trade_id}.reference", reference) self.irc.sendmsg(f"Generated reference for {trade_id}: {reference}") self.ux.notify.notify_new_trade(amount, currency) uid = self.create_uid(subclass, buyer) @@ -599,11 +599,11 @@ class Transactions(util.Base): :return: matching trade object or False :rtype: dict or bool """ - refs = self.get_refs() + refs = db.get_refs() matching_refs = [] # TODO: use get_ref_map in this function instead of calling get_ref multiple times for ref in refs: - stored_trade = self.get_ref(ref) + stored_trade = db.get_ref(ref) if stored_trade["currency"] == currency and float(stored_trade["amount"]) == float(amount): matching_refs.append(stored_trade) if len(matching_refs) != 1: @@ -611,119 +611,6 @@ class Transactions(util.Base): return False return matching_refs[0] - def get_refs(self): - """ - Get all reference IDs for trades. - :return: list of trade IDs - :rtype: list - """ - references = [] - ref_keys = r.keys("trade.*.reference") - for key in ref_keys: - references.append(r.get(key)) - return util.convert(references) - - def get_ref_map(self): - """ - Get all reference IDs for trades. - :return: dict of references keyed by TXID - :rtype: dict - """ - references = {} - ref_keys = r.keys("trade.*.reference") - for key in ref_keys: - tx = util.convert(key).split(".")[1] - references[tx] = r.get(key) - return util.convert(references) - - def get_ref(self, reference): - """ - Get the trade information for a reference. - :param reference: trade reference - :type reference: string - :return: dict of trade information - :rtype: dict - """ - ref_data = r.hgetall(f"trade.{reference}") - ref_data = util.convert(ref_data) - if "subclass" not in ref_data: - ref_data["subclass"] = "agora" - if not ref_data: - return False - return ref_data - - def get_tx(self, tx): - """ - Get the transaction information for a transaction ID. - :param reference: trade reference - :type reference: string - :return: dict of trade information - :rtype: dict - """ - tx_data = r.hgetall(f"tx.{tx}") - tx_data = util.convert(tx_data) - if not tx_data: - return False - return tx_data - - def get_subclass(self, reference): - obj = r.hget(f"trade.{reference}", "subclass") - subclass = util.convert(obj) - return subclass - - def del_ref(self, reference): - """ - Delete a given reference from the Redis database. - :param reference: trade reference to delete - :type reference: string - """ - tx = self.ref_to_tx(reference) - r.delete(f"trade.{reference}") - r.delete(f"trade.{tx}.reference") - - def cleanup(self, subclass, references): - """ - Reconcile the internal reference database with a given list of references. - Delete all internal references not present in the list and clean up artifacts. - :param references: list of references to reconcile against - :type references: list - """ - for tx, reference in self.get_ref_map().items(): - if reference not in references: - if self.get_subclass(reference) == subclass: - self.log.info(f"Archiving trade reference: {reference} / TX: {tx}") - r.rename(f"trade.{tx}.reference", f"archive.trade.{tx}.reference") - r.rename(f"trade.{reference}", f"archive.trade.{reference}") - - def tx_to_ref(self, tx): - """ - Convert a trade ID to a reference. - :param tx: trade ID - :type tx: string - :return: reference - :rtype: string - """ - refs = self.get_refs() - for reference in refs: - ref_data = util.convert(r.hgetall(f"trade.{reference}")) - if not ref_data: - continue - if ref_data["id"] == tx: - return reference - - def ref_to_tx(self, reference): - """ - Convert a reference to a trade ID. - :param reference: trade reference - :type reference: string - :return: trade ID - :rtype: string - """ - ref_data = util.convert(r.hgetall(f"trade.{reference}")) - if not ref_data: - return False - return ref_data["id"] - @inlineCallbacks def get_total_usd(self): """