diff --git a/core/clients/aggregator.py b/core/clients/aggregator.py index 3e1c14e..d08daa1 100644 --- a/core/clients/aggregator.py +++ b/core/clients/aggregator.py @@ -50,10 +50,13 @@ class AggregatorClient(ABC): self.instance.currencies = currencies self.instance.save() - async def process_transactions(self, account_id, transactions): + async def process_transactions(self, account_id, transactions, req): if not transactions: return False + if not req: + return False + platforms = self.instance.platforms for transaction in transactions: transaction_id = transaction["transaction_id"] @@ -71,6 +74,7 @@ class AggregatorClient(ABC): "note": transaction["reference"], } tx_obj = self.instance.add_transaction( + req, account_id, tx_cast, ) diff --git a/core/clients/aggregators/nordigen.py b/core/clients/aggregators/nordigen.py index 0f8c147..1a90c3d 100644 --- a/core/clients/aggregators/nordigen.py +++ b/core/clients/aggregators/nordigen.py @@ -294,7 +294,9 @@ class NordigenClient(BaseClient, AggregatorClient): else: raise Exception(f"No way to get reference: {transaction}") - async def get_transactions(self, account_id, process=False, pending=False): + async def get_transactions( + self, account_id, req=None, process=False, pending=False + ): """ Get all transactions for an account. :param account_id: account to fetch transactions for @@ -307,7 +309,7 @@ class NordigenClient(BaseClient, AggregatorClient): self.normalise_transactions(parsed, state="booked") if process: - await self.process_transactions(account_id, parsed) + await self.process_transactions(account_id, parsed, req=req) if pending: parsed_pending = response["pending"] self.normalise_transactions(parsed_pending, state="pending") diff --git a/core/clients/platform.py b/core/clients/platform.py index 604bfe4..81de66f 100644 --- a/core/clients/platform.py +++ b/core/clients/platform.py @@ -750,6 +750,53 @@ class LocalPlatformClient(ABC): return True + async def reset_trade_tx(self, stored_trade, tx_obj): + """ + Remove a trade to point to a given transaction ID. + """ + if not tx_obj.reconciled: + return False + + if tx_obj not in stored_trade.linked.all(): + return False + + stored_trade.linked.remove(tx_obj) + stored_trade.save() + + tx_obj.reconciled = False + tx_obj.save() + + return True + + async def successful_release(self, trade, transaction): + """ + Called when a trade has been successfully released. + Increment the platform and requisition throughput by the trade value. + Currently only XMR is supported. + """ + if trade.asset != "XMR": + raise NotImplementedError("Only XMR is supported at the moment.") + + # Increment the platform throughput + trade.platform.throughput += trade.amount_crypto + + # Increment the requisition throughput + if transaction.requisition is not None: + transaction.requisition.throughput += trade.amount_crypto + + async def successful_withdrawal(self): + platforms = self.instance.platforms.all() + aggregators = self.instance.aggregators.all() + + for platform in platforms: + platform.throughput = 0 + platform.save() + + for aggregator in aggregators: + for requisition in aggregator.requisitions.all(): + requisition.throughput = 0 + requisition.save() + async def release_map_trade(self, stored_trade, tx_obj): """ Map a trade to a transaction and release if no other TX is @@ -761,9 +808,15 @@ class LocalPlatformClient(ABC): is_updated = await self.update_trade_tx(stored_trade, tx_obj) if is_updated is True: # We mapped the trade successfully - await self.release_trade_escrow(trade_id, stored_trade.reference) await antifraud.add_bank_sender(platform_buyer, bank_sender) - return True + released = await self.release_trade_escrow(trade_id, stored_trade.reference) + if not released: + # We failed to release the funds + # Set the TX back to not reconciled, so we can try this again + await self.reset_trade_tx(stored_trade, tx_obj) + return False + await self.successful_release(stored_trade, tx_obj) + return released else: # Already mapped log.error( diff --git a/core/clients/platforms/agora.py b/core/clients/platforms/agora.py index 26c17ce..3e4596f 100644 --- a/core/clients/platforms/agora.py +++ b/core/clients/platforms/agora.py @@ -119,3 +119,5 @@ class AgoraClient(LocalPlatformClient, BaseClient): # self.irc.sendmsg(f"Withdrawal: {rtrn1['success']} | {rtrn2['success']}") # self.ux.notify.notify_withdrawal(half_rounded) + + # await self.successful_withdrawal() diff --git a/core/management/commands/scheduling.py b/core/management/commands/scheduling.py index c235f8f..a5783e2 100644 --- a/core/management/commands/scheduling.py +++ b/core/management/commands/scheduling.py @@ -27,7 +27,10 @@ async def aggregator_job(): for bank, accounts in aggregator.account_info.items(): for account in accounts: account_id = account["account_id"] - task = instance.get_transactions(account_id, process=True) + requisition_id = account["requisition_id"] + task = instance.get_transactions( + account_id, req=requisition_id, process=True + ) fetch_tasks.append(task) await asyncio.gather(*fetch_tasks) else: diff --git a/core/migrations/0034_transaction_requisition.py b/core/migrations/0034_transaction_requisition.py new file mode 100644 index 0000000..726771e --- /dev/null +++ b/core/migrations/0034_transaction_requisition.py @@ -0,0 +1,19 @@ +# Generated by Django 4.1.7 on 2023-03-20 13:35 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0033_platform_throughput'), + ] + + operations = [ + migrations.AddField( + model_name='transaction', + name='requisition', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, to='core.requisition'), + ), + ] diff --git a/core/models.py b/core/models.py index 181f229..7dbcccc 100644 --- a/core/models.py +++ b/core/models.py @@ -151,17 +151,17 @@ class Aggregator(models.Model): Then, join them all together. """ return Platform.objects.filter(link_group=self.link_group) - # platforms = [] - # linkgroups = LinkGroup.objects.filter( - # aggregators=self, - # enabled=True, - # ) - # for link in linkgroups: - # for platform in link.platforms.all(): - # if platform not in platforms: - # platforms.append(platform) - # return platforms + @property + def requisitions(self): + """ + Get requisitions for this aggregator. + Do this by looking up LinkGroups with the aggregator. + Then, join them all together. + """ + return Requisition.objects.filter( + aggregator=self, + ) @classmethod def get_currencies_for_platform(cls, platform): @@ -187,7 +187,12 @@ class Aggregator(models.Model): account_info[bank].append(account) return account_info - def add_transaction(self, account_id, tx_data): + def add_transaction(self, requisition_id, account_id, tx_data): + requisition = Requisition.objects.filter( + aggregator=self, requisition_id=requisition_id + ).first() + if requisition: + tx_data["requisition"] = requisition return Transaction.objects.create( aggregator=self, account_id=account_id, @@ -426,22 +431,23 @@ class Platform(models.Model): Do this by looking up LinkGroups with the platform. Then, join them all together. """ - # aggregators = [] - # linkgroups = LinkGroup.objects.filter( - # platforms=self, - # enabled=True, - # ) - # for link in linkgroups: - # for aggregator in link.aggregators.all(): - # if aggregator not in aggregators: - # aggregators.append(aggregator) - - # return aggregators return Aggregator.objects.filter( link_group=self.link_group, ) + @property + def platforms(self): + """ + Get all platforms in this link group. + Do this by looking up LinkGroups with the platform. + Then, join them all together. + """ + + return Platform.objects.filter( + link_group=self.link_group, + ) + def get_requisition(self, aggregator_id, requisition_id): """ Get a Requisition object with the provided values. @@ -524,6 +530,9 @@ class Transaction(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) aggregator = models.ForeignKey(Aggregator, on_delete=models.CASCADE) + requisition = models.ForeignKey( + "core.Requisition", null=True, on_delete=models.CASCADE + ) account_id = models.CharField(max_length=255) transaction_id = models.CharField(max_length=255) diff --git a/core/templates/partials/linkgroup-info.html b/core/templates/partials/linkgroup-info.html index 29db554..e39ce99 100644 --- a/core/templates/partials/linkgroup-info.html +++ b/core/templates/partials/linkgroup-info.html @@ -13,6 +13,7 @@