Move reference handling code to DB

This commit is contained in:
Mark Veidemanis 2022-05-05 13:01:24 +01:00
parent 6504c440e0
commit 22520c8224
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
6 changed files with 159 additions and 142 deletions

View File

@ -3,6 +3,131 @@ from redis import StrictRedis
# Project imports # Project imports
from settings import settings from settings import settings
import util
log = util.get_logger("db")
# Define the Redis endpoint to the socket # Define the Redis endpoint to the socket
r = StrictRedis(unix_socket_path=settings.DB.RedisSocket, db=int(settings.DB.DB)) r = StrictRedis(unix_socket_path=settings.DB.RedisSocket, db=int(settings.DB.DB))
def get_refs():
"""
Get all reference IDs for trades.
:return: list of trade IDs
:rtype: list
"""
references = []
ref_keys = r.keys("trade.*.reference")
for key in ref_keys:
references.append(r.get(key))
return util.convert(references)
def tx_to_ref(self, tx):
"""
Convert a trade ID to a reference.
:param tx: trade ID
:type tx: string
:return: reference
:rtype: string
"""
refs = get_refs()
for reference in refs:
ref_data = util.convert(r.hgetall(f"trade.{reference}"))
if not ref_data:
continue
if ref_data["id"] == tx:
return reference
def ref_to_tx(self, reference):
"""
Convert a reference to a trade ID.
:param reference: trade reference
:type reference: string
:return: trade ID
:rtype: string
"""
ref_data = util.convert(r.hgetall(f"trade.{reference}"))
if not ref_data:
return False
return ref_data["id"]
def get_ref_map():
"""
Get all reference IDs for trades.
:return: dict of references keyed by TXID
:rtype: dict
"""
references = {}
ref_keys = r.keys("trade.*.reference")
for key in ref_keys:
tx = util.convert(key).split(".")[1]
references[tx] = r.get(key)
return util.convert(references)
def get_ref(reference):
"""
Get the trade information for a reference.
:param reference: trade reference
:type reference: string
:return: dict of trade information
:rtype: dict
"""
ref_data = r.hgetall(f"trade.{reference}")
ref_data = util.convert(ref_data)
if "subclass" not in ref_data:
ref_data["subclass"] = "agora"
if not ref_data:
return False
return ref_data
def get_tx(tx):
"""
Get the transaction information for a transaction ID.
:param reference: trade reference
:type reference: string
:return: dict of trade information
:rtype: dict
"""
tx_data = r.hgetall(f"tx.{tx}")
tx_data = util.convert(tx_data)
if not tx_data:
return False
return tx_data
def get_subclass(reference):
obj = r.hget(f"trade.{reference}", "subclass")
subclass = util.convert(obj)
return subclass
def del_ref(reference):
"""
Delete a given reference from the Redis database.
:param reference: trade reference to delete
:type reference: string
"""
tx = ref_to_tx(reference)
r.delete(f"trade.{reference}")
r.delete(f"trade.{tx}.reference")
def cleanup(subclass, references):
"""
Reconcile the internal reference database with a given list of references.
Delete all internal references not present in the list and clean up artifacts.
:param references: list of references to reconcile against
:type references: list
"""
for tx, reference in get_ref_map().items():
if reference not in references:
if get_subclass(reference) == subclass:
log.info(f"Archiving trade reference: {reference} / TX: {tx}")
r.rename(f"trade.{tx}.reference", f"archive.trade.{tx}.reference")
r.rename(f"trade.{reference}", f"archive.trade.{reference}")

View File

@ -13,6 +13,7 @@ from settings import settings
import util import util
from lib.agoradesk_py import AgoraDesk from lib.agoradesk_py import AgoraDesk
from lib.localbitcoins_py import LocalBitcoins from lib.localbitcoins_py import LocalBitcoins
import db
class Local(util.Base): class Local(util.Base):
@ -89,7 +90,7 @@ class Local(util.Base):
if dash is False: if dash is False:
return False return False
for contact_id, contact in dash.items(): for contact_id, contact in dash.items():
reference = self.tx.tx_to_ref(contact_id) reference = db.tx_to_ref(contact_id)
buyer = contact["data"]["buyer"]["username"] buyer = contact["data"]["buyer"]["username"]
amount = contact["data"]["amount"] amount = contact["data"]["amount"]
if self.platform == "agora": if self.platform == "agora":
@ -123,7 +124,7 @@ class Local(util.Base):
if not dash.items(): if not dash.items():
return return
for contact_id, contact in dash.items(): for contact_id, contact in dash.items():
reference = self.tx.tx_to_ref(str(contact_id)) reference = db.tx_to_ref(str(contact_id))
if reference: if reference:
current_trades.append(reference) current_trades.append(reference)
buyer = contact["data"]["buyer"]["username"] buyer = contact["data"]["buyer"]["username"]
@ -170,7 +171,7 @@ class Local(util.Base):
self.last_dash.remove(ref) self.last_dash.remove(ref)
if reference and reference not in current_trades: if reference and reference not in current_trades:
current_trades.append(reference) current_trades.append(reference)
self.tx.cleanup(self.platform, current_trades) db.cleanup(self.platform, current_trades)
def got_recent_messages(self, messages, send_irc=True): def got_recent_messages(self, messages, send_irc=True):
""" """
@ -184,14 +185,14 @@ class Local(util.Base):
if "data" not in messages["response"]: if "data" not in messages["response"]:
self.log.error(f"Data not in messages response: {messages['response']}") self.log.error(f"Data not in messages response: {messages['response']}")
return False return False
open_tx = self.tx.get_ref_map().keys() open_tx = db.get_ref_map().keys()
for message in messages["response"]["data"]["message_list"]: for message in messages["response"]["data"]["message_list"]:
contact_id = message["contact_id"] contact_id = message["contact_id"]
username = message["sender"]["username"] username = message["sender"]["username"]
msg = message["msg"] msg = message["msg"]
if contact_id not in open_tx: if contact_id not in open_tx:
continue continue
reference = self.tx.tx_to_ref(contact_id) reference = db.tx_to_ref(contact_id)
if reference in messages_tmp: if reference in messages_tmp:
messages_tmp[reference].append([username, msg]) messages_tmp[reference].append([username, msg])
else: else:

View File

@ -1,5 +1,7 @@
from unittest import TestCase from unittest import TestCase
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
# from twisted.internet.defer import inlineCallbacks
from json import loads from json import loads
from copy import deepcopy from copy import deepcopy

View File

@ -1,5 +1,7 @@
from unittest import TestCase from unittest import TestCase
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
# from twisted.internet.defer import inlineCallbacks
from json import loads from json import loads
from copy import deepcopy from copy import deepcopy

View File

@ -30,10 +30,10 @@ class TestTransactions(TestCase):
} }
# Mock redis calls # Mock redis calls
transactions.r.hgetall = self.mock_hgetall transactions.db.r.hgetall = self.mock_hgetall
transactions.r.hmset = self.mock_hmset transactions.db.r.hmset = self.mock_hmset
transactions.r.keys = self.mock_keys transactions.db.r.keys = self.mock_keys
transactions.r.get = self.mock_get transactions.db.r.get = self.mock_get
# Mock some callbacks # Mock some callbacks
self.transactions.irc = MagicMock() self.transactions.irc = MagicMock()

View File

@ -13,7 +13,7 @@ import logging
# Project imports # Project imports
from settings import settings from settings import settings
from db import r import db
import util import util
# TODO: secure ES traffic properly # TODO: secure ES traffic properly
@ -131,7 +131,7 @@ class Transactions(util.Base):
# Split the reference into parts # Split the reference into parts
ref_split = reference.split(" ") ref_split = reference.split(" ")
# Get all existing references # Get all existing references
existing_refs = self.get_refs() existing_refs = db.get_refs()
# Get all parts of the given reference split that match the existing references # Get all parts of the given reference split that match the existing references
stored_trade_reference = set(existing_refs).intersection(set(ref_split)) stored_trade_reference = set(existing_refs).intersection(set(ref_split))
if len(stored_trade_reference) > 1: if len(stored_trade_reference) > 1:
@ -174,7 +174,7 @@ class Transactions(util.Base):
return stored_trade return stored_trade
def normal_lookup(self, stored_trade_reference, reference, currency, amount): def normal_lookup(self, stored_trade_reference, reference, currency, amount):
stored_trade = self.get_ref(stored_trade_reference) stored_trade = db.get_ref(stored_trade_reference)
if not stored_trade: if not stored_trade:
self.log.info(f"No reference in DB for {reference}") self.log.info(f"No reference in DB for {reference}")
self.irc.sendmsg(f"No reference in DB for {reference}") self.irc.sendmsg(f"No reference in DB for {reference}")
@ -227,7 +227,7 @@ class Transactions(util.Base):
:param bank_sender: the sender name from the bank :param bank_sender: the sender name from the bank
""" """
key = f"namemap.{platform}.{platform_buyer}" key = f"namemap.{platform}.{platform_buyer}"
r.sadd(key, bank_sender) db.r.sadd(key, bank_sender)
def get_previous_senders(self, platform, platform_buyer): def get_previous_senders(self, platform, platform_buyer):
""" """
@ -238,7 +238,7 @@ class Transactions(util.Base):
:rtype: set :rtype: set
""" """
key = f"namemap.{platform}.{platform_buyer}" key = f"namemap.{platform}.{platform_buyer}"
senders = r.smembers(key) senders = db.r.smembers(key)
if not senders: if not senders:
return None return None
senders = util.convert(senders) senders = util.convert(senders)
@ -273,10 +273,10 @@ class Transactions(util.Base):
:param tx: the transaction ID :param tx: the transaction ID
:param reference: the trade reference :param reference: the trade reference
""" """
stored_trade = self.get_ref(reference) stored_trade = db.get_ref(reference)
if not stored_trade: if not stored_trade:
return None return None
stored_tx = self.get_tx(tx) stored_tx = db.get_tx(tx)
if not stored_tx: if not stored_tx:
return None return None
bank_sender = stored_tx["sender"] bank_sender = stored_tx["sender"]
@ -292,11 +292,11 @@ class Transactions(util.Base):
Update a trade to point to a given transaction ID. Update a trade to point to a given transaction ID.
Return False if the trade already has a mapped transaction. Return False if the trade already has a mapped transaction.
""" """
existing_tx = r.hget(f"trade.{reference}", "tx") existing_tx = db.r.hget(f"trade.{reference}", "tx")
if existing_tx is None: if existing_tx is None:
return None return None
elif existing_tx == b"": elif existing_tx == b"":
r.hset(f"trade.{reference}", "tx", txid) db.r.hset(f"trade.{reference}", "tx", txid)
return True return True
else: # Already a mapped transaction else: # Already a mapped transaction
return False return False
@ -329,7 +329,7 @@ class Transactions(util.Base):
"currency": currency, "currency": currency,
"sender": sender, "sender": sender,
} }
r.hmset(f"tx.{txid}", to_store) db.r.hmset(f"tx.{txid}", to_store)
self.log.info(f"Transaction processed: {dumps(to_store, indent=2)}") self.log.info(f"Transaction processed: {dumps(to_store, indent=2)}")
self.irc.sendmsg(f"AUTO Incoming transaction on {subclass}: {txid} {amount}{currency} ({reference})") self.irc.sendmsg(f"AUTO Incoming transaction on {subclass}: {txid} {amount}{currency} ({reference})")
@ -390,7 +390,7 @@ class Transactions(util.Base):
self.release_funds(stored_trade["id"], stored_trade["reference"]) self.release_funds(stored_trade["id"], stored_trade["reference"])
def release_funds(self, trade_id, reference): def release_funds(self, trade_id, reference):
stored_trade = self.get_ref(reference) stored_trade = db.get_ref(reference)
platform = stored_trade["subclass"] platform = stored_trade["subclass"]
logmessage = f"All checks passed, releasing funds for {trade_id} {reference}" logmessage = f"All checks passed, releasing funds for {trade_id} {reference}"
self.log.info(logmessage) self.log.info(logmessage)
@ -422,11 +422,11 @@ class Transactions(util.Base):
Map a trade to a transaction and release if no other TX is Map a trade to a transaction and release if no other TX is
mapped to the same trade. mapped to the same trade.
""" """
stored_trade = self.get_ref(reference) stored_trade = db.get_ref(reference)
if not stored_trade: if not stored_trade:
self.log.error(f"Could not get stored trade for {reference}.") self.log.error(f"Could not get stored trade for {reference}.")
return None return None
tx_obj = self.get_tx(tx) tx_obj = db.get_tx(tx)
if not tx_obj: if not tx_obj:
self.log.error(f"Could not get TX for {tx}.") self.log.error(f"Could not get TX for {tx}.")
return None return None
@ -453,10 +453,10 @@ class Transactions(util.Base):
:return: tuple of (platform, trade_id, reference, currency) :return: tuple of (platform, trade_id, reference, currency)
""" """
platform, username = self.get_uid(uid) platform, username = self.get_uid(uid)
refs = self.get_refs() refs = db.get_refs()
matching_trades = [] matching_trades = []
for reference in refs: for reference in refs:
ref_data = self.get_ref(reference) ref_data = db.get_ref(reference)
tx_platform = ref_data["subclass"] tx_platform = ref_data["subclass"]
tx_username = ref_data["buyer"] tx_username = ref_data["buyer"]
trade_id = ref_data["id"] trade_id = ref_data["id"]
@ -553,7 +553,7 @@ class Transactions(util.Base):
""" """
reference = "".join(choices(ascii_uppercase, k=5)) reference = "".join(choices(ascii_uppercase, k=5))
reference = f"PGN-{reference}" reference = f"PGN-{reference}"
existing_ref = r.get(f"trade.{trade_id}.reference") existing_ref = db.r.get(f"trade.{trade_id}.reference")
if not existing_ref: if not existing_ref:
to_store = { to_store = {
"id": trade_id, "id": trade_id,
@ -568,8 +568,8 @@ class Transactions(util.Base):
"subclass": subclass, "subclass": subclass,
} }
self.log.info(f"Storing trade information: {str(to_store)}") self.log.info(f"Storing trade information: {str(to_store)}")
r.hmset(f"trade.{reference}", to_store) db.r.hmset(f"trade.{reference}", to_store)
r.set(f"trade.{trade_id}.reference", reference) db.r.set(f"trade.{trade_id}.reference", reference)
self.irc.sendmsg(f"Generated reference for {trade_id}: {reference}") self.irc.sendmsg(f"Generated reference for {trade_id}: {reference}")
self.ux.notify.notify_new_trade(amount, currency) self.ux.notify.notify_new_trade(amount, currency)
uid = self.create_uid(subclass, buyer) uid = self.create_uid(subclass, buyer)
@ -599,11 +599,11 @@ class Transactions(util.Base):
:return: matching trade object or False :return: matching trade object or False
:rtype: dict or bool :rtype: dict or bool
""" """
refs = self.get_refs() refs = db.get_refs()
matching_refs = [] matching_refs = []
# TODO: use get_ref_map in this function instead of calling get_ref multiple times # TODO: use get_ref_map in this function instead of calling get_ref multiple times
for ref in refs: for ref in refs:
stored_trade = self.get_ref(ref) stored_trade = db.get_ref(ref)
if stored_trade["currency"] == currency and float(stored_trade["amount"]) == float(amount): if stored_trade["currency"] == currency and float(stored_trade["amount"]) == float(amount):
matching_refs.append(stored_trade) matching_refs.append(stored_trade)
if len(matching_refs) != 1: if len(matching_refs) != 1:
@ -611,119 +611,6 @@ class Transactions(util.Base):
return False return False
return matching_refs[0] return matching_refs[0]
def get_refs(self):
"""
Get all reference IDs for trades.
:return: list of trade IDs
:rtype: list
"""
references = []
ref_keys = r.keys("trade.*.reference")
for key in ref_keys:
references.append(r.get(key))
return util.convert(references)
def get_ref_map(self):
"""
Get all reference IDs for trades.
:return: dict of references keyed by TXID
:rtype: dict
"""
references = {}
ref_keys = r.keys("trade.*.reference")
for key in ref_keys:
tx = util.convert(key).split(".")[1]
references[tx] = r.get(key)
return util.convert(references)
def get_ref(self, reference):
"""
Get the trade information for a reference.
:param reference: trade reference
:type reference: string
:return: dict of trade information
:rtype: dict
"""
ref_data = r.hgetall(f"trade.{reference}")
ref_data = util.convert(ref_data)
if "subclass" not in ref_data:
ref_data["subclass"] = "agora"
if not ref_data:
return False
return ref_data
def get_tx(self, tx):
"""
Get the transaction information for a transaction ID.
:param reference: trade reference
:type reference: string
:return: dict of trade information
:rtype: dict
"""
tx_data = r.hgetall(f"tx.{tx}")
tx_data = util.convert(tx_data)
if not tx_data:
return False
return tx_data
def get_subclass(self, reference):
obj = r.hget(f"trade.{reference}", "subclass")
subclass = util.convert(obj)
return subclass
def del_ref(self, reference):
"""
Delete a given reference from the Redis database.
:param reference: trade reference to delete
:type reference: string
"""
tx = self.ref_to_tx(reference)
r.delete(f"trade.{reference}")
r.delete(f"trade.{tx}.reference")
def cleanup(self, subclass, references):
"""
Reconcile the internal reference database with a given list of references.
Delete all internal references not present in the list and clean up artifacts.
:param references: list of references to reconcile against
:type references: list
"""
for tx, reference in self.get_ref_map().items():
if reference not in references:
if self.get_subclass(reference) == subclass:
self.log.info(f"Archiving trade reference: {reference} / TX: {tx}")
r.rename(f"trade.{tx}.reference", f"archive.trade.{tx}.reference")
r.rename(f"trade.{reference}", f"archive.trade.{reference}")
def tx_to_ref(self, tx):
"""
Convert a trade ID to a reference.
:param tx: trade ID
:type tx: string
:return: reference
:rtype: string
"""
refs = self.get_refs()
for reference in refs:
ref_data = util.convert(r.hgetall(f"trade.{reference}"))
if not ref_data:
continue
if ref_data["id"] == tx:
return reference
def ref_to_tx(self, reference):
"""
Convert a reference to a trade ID.
:param reference: trade reference
:type reference: string
:return: trade ID
:rtype: string
"""
ref_data = util.convert(r.hgetall(f"trade.{reference}"))
if not ref_data:
return False
return ref_data["id"]
@inlineCallbacks @inlineCallbacks
def get_total_usd(self): def get_total_usd(self):
""" """