From 479e5b1022b5d955f08b74d06afd633b3fd4691d Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Wed, 8 Mar 2023 12:48:05 +0000 Subject: [PATCH] Implement adding bank links --- app/urls.py | 22 ++ core/__init__.py | 2 - core/clients/__init__.py | 0 core/clients/aggregators/nordigen.py | 99 +++++++++ core/clients/base.py | 206 ++++++++++++++++++ core/lib/schemas/__init__.py | 1 + core/lib/schemas/nordigen_s.py | 77 +++++++ .../0004_aggregator_access_token_expires.py | 18 ++ core/models.py | 12 + .../partials/aggregator-countries.html | 36 +++ .../partials/aggregator-country-banks.html | 38 ++++ core/templates/partials/aggregator-info.html | 75 +++++++ core/templates/partials/aggregator-list.html | 31 +-- core/views/aggregators.py | 177 ++++++++++++++- core/views/base.py | 4 +- requirements.txt | 5 +- 16 files changed, 764 insertions(+), 39 deletions(-) create mode 100644 core/clients/__init__.py create mode 100644 core/clients/aggregators/nordigen.py create mode 100644 core/clients/base.py create mode 100644 core/lib/schemas/__init__.py create mode 100644 core/lib/schemas/nordigen_s.py create mode 100644 core/migrations/0004_aggregator_access_token_expires.py create mode 100644 core/templates/partials/aggregator-countries.html create mode 100644 core/templates/partials/aggregator-country-banks.html create mode 100644 core/templates/partials/aggregator-info.html diff --git a/app/urls.py b/app/urls.py index abd942b..d379637 100644 --- a/app/urls.py +++ b/app/urls.py @@ -59,4 +59,26 @@ urlpatterns = [ aggregators.AggregatorDelete.as_view(), name="aggregator_delete", ), + # Aggregator Requisitions + path( + "aggs//info//", + aggregators.ReqsList.as_view(), + name="reqs", + ), + # Aggregator Account link flow + path( + "aggs//countries//", + aggregators.AggregatorCountriesList.as_view(), + name="aggregator_countries", + ), + path( + "aggs//countries///banks/", + aggregators.AggregatorCountryBanksList.as_view(), + name="aggregator_country_banks", + ), + path( + "aggs//link///", + aggregators.AggregatorLinkBank.as_view(), + name="aggregator_link", + ), ] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) diff --git a/core/__init__.py b/core/__init__.py index 15764c8..3463c6f 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -1,8 +1,6 @@ import os # import stripe -from django.conf import settings - os.environ["DJANGO_ALLOW_ASYNC_UNSAFE"] = "true" # from redis import StrictRedis diff --git a/core/clients/__init__.py b/core/clients/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/clients/aggregators/nordigen.py b/core/clients/aggregators/nordigen.py new file mode 100644 index 0000000..2dfe375 --- /dev/null +++ b/core/clients/aggregators/nordigen.py @@ -0,0 +1,99 @@ +from datetime import timedelta + +from django.conf import settings +from django.utils import timezone + +from core.clients.base import BaseClient +from core.util import logs + +log = logs.get_logger("nordigen") + + +class NordigenClient(BaseClient): + url = "https://ob.nordigen.com/api/v2" + + async def connect(self): + now = timezone.now() + # Check if access token expires later than now + if self.instance.access_token_expires is not None: + if self.instance.access_token_expires > now: + self.token = self.instance.access_token + return + await self.get_access_token() + + def method_filter(self, method): + new_method = method.replace("/", "_") + return new_method + + async def get_access_token(self): + """ + Get the access token for the Nordigen API. + """ + log.debug(f"Getting new access token for {self.instance}") + data = { + "secret_id": self.instance.secret_id, + "secret_key": self.instance.secret_key, + } + + response = await self.call("token/new", http_method="post", data=data) + print("RESPONSE IN GET ACCESS TOKEN", response) # + access = response["access"] + access_expires = response["access_expires"] + print("ACCESS EXPIRES", access_expires) + now = timezone.now() + # Offset now by access_expires seconds + access_expires = now + timedelta(seconds=access_expires) + print("ACCESS EXPIRES", access_expires) + self.instance.access_token = access + self.instance.access_token_expires = access_expires + self.instance.save() + + self.token = access + + async def get_requisitions(self): + """ + Get a list of active accounts. + """ + response = await self.call("requisitions") + return response["results"] + + async def get_countries(self): + """ + Get a list of countries. + """ + # This function is a stub. + + return ["GB", "SE"] + + async def get_banks(self, country): + """ + Get a list of supported banks for a country. + :param country: country to query + :return: list of institutions + :rtype: list + """ + if not len(country) == 2: + return False + path = f"institutions/?country={country}" + response = await self.call(path, schema="Institutions", append_slash=False) + + return response + + async def build_link(self, institution_id, redirect=None): + """Create a link to access an institution. + :param institution_id: ID of the institution + """ + + data = { + "institution_id": institution_id, + "redirect": settings.URL, + } + if redirect: + data["redirect"] = redirect + response = await self.call( + "requisitions", schema="RequisitionsPost", http_method="post", data=data + ) + print("build_link response", response) + if "link" in response: + return response["link"] + return False diff --git a/core/clients/base.py b/core/clients/base.py new file mode 100644 index 0000000..55988a6 --- /dev/null +++ b/core/clients/base.py @@ -0,0 +1,206 @@ +from abc import ABC, abstractmethod + +import aiohttp +import orjson +from glom import glom +from pydantic.error_wrappers import ValidationError + +from core.lib import schemas +from core.util import logs + +# Return error if the schema for the message type is not found +STRICT_VALIDATION = False + +# Raise exception if the conversion schema is not found +STRICT_CONVERSION = False + +# TODO: Set them to True when all message types are implemented + +log = logs.get_logger("clients") + + +class NoSchema(Exception): + """ + Raised when: + - The schema for the message type is not found + - The conversion schema is not found + - There is no schema library for the client + """ + + pass + + +class NoSuchMethod(Exception): + """ + Client library has no such method. + """ + + pass + + +class GenericAPIError(Exception): + """ + Generic API error. + """ + + pass + + +def is_camel_case(s): + return s != s.lower() and s != s.upper() and "_" not in s + + +def snake_to_camel(word): + if is_camel_case(word): + return word + return "".join(x.capitalize() or "_" for x in word.split("_")) + + +DEFAULT_HEADERS = { + "accept": "application/json", + "Content-Type": "application/json", +} + + +class BaseClient(ABC): + token = None + + async def __new__(cls, *a, **kw): + instance = super().__new__(cls) + await instance.__init__(*a, **kw) + return instance + + async def __init__(self, instance): + """ + Initialise the client. + :param instance: the database object, e.g. Aggregator + """ + name = self.__class__.__name__ + self.name = name.replace("Client", "").lower() + self.instance = instance + self.client = None + + await self.connect() + + @abstractmethod + async def connect(self): + pass + + @property + def schema(self): + """ + Get the schema library for the client. + """ + # Does the schemas library have a library for this client name? + if hasattr(schemas, f"{self.name}_s"): + schema_instance = getattr(schemas, f"{self.name}_s") + else: + log.error(f"No schema library for {self.name}") + raise Exception(f"No schema library for client {self.name}") + + return schema_instance + + def get_schema(self, method, convert=False): + if isinstance(method, str): + to_camel = snake_to_camel(method) + else: + to_camel = snake_to_camel(method.__class__.__name__) + if convert: + to_camel = f"{to_camel}Schema" + + # if hasattr(self.schema, method): + # schema = getattr(self.schema, method) + if hasattr(self.schema, to_camel): + schema = getattr(self.schema, to_camel) + else: + raise NoSchema(f"Could not get schema: {to_camel}") + return schema + + async def call_method(self, method, *args, **kwargs): + """ + Call a method with aiohttp. + """ + if kwargs.get("append_slash", True): + path = f"{self.url}/{method}/" + else: + path = f"{self.url}/{method}" + + http_method = kwargs.get("http_method", "get") + + cast = { + "headers": DEFAULT_HEADERS, + } + + print("TOKEN", self.token) + # Use the token if it's set + if self.token is not None: + cast["headers"]["Authorization"] = f"Bearer {self.token}" + + if "data" in kwargs: + cast["data"] = orjson.dumps(kwargs["data"]) + + # Use the method to send a HTTP request + async with aiohttp.ClientSession() as session: + session_method = getattr(session, http_method) + async with session_method(path, **cast) as response: + response_json = await response.json() + return response_json + + def convert_spec(self, response, method): + """ + Convert an API response to the requested spec. + :raises NoSchema: If the conversion schema is not found + """ + schema = self.get_schema(method, convert=True) + + # Use glom to convert the response to the schema + converted = glom(response, schema) + return converted + + def validate_response(self, response, method): + schema = self.get_schema(method) + # Return a dict of the validated response + try: + response_valid = schema(**response).dict() + except ValidationError as e: + log.error(f"Error validating {method} response: {response}") + log.error(f"Errors: {e}") + raise GenericAPIError("Error validating response") + return response_valid + + def method_filter(self, method): + """ + Return a new method. + """ + return method + + async def call(self, method, *args, **kwargs): + """ + Call the exchange API and validate the response + :raises NoSchema: If the method is not in the schema mapping + :raises ValidationError: If the response cannot be validated + """ + # try: + response = await self.call_method(method, *args, **kwargs) + # except (APIError, V20Error) as e: + # log.error(f"Error calling method {method}: {e}") + # raise GenericAPIError(e) + + if "schema" in kwargs: + method = kwargs["schema"] + else: + method = self.method_filter(method) + try: + response_valid = self.validate_response(response, method) + except NoSchema as e: + log.error(f"{e} - {response}") + response_valid = response + # Convert the response to a format that we can use + try: + response_converted = self.convert_spec(response_valid, method) + except NoSchema as e: + log.error(f"{e} - {response}") + response_converted = response_valid + + # return (True, response_converted) + return response_converted diff --git a/core/lib/schemas/__init__.py b/core/lib/schemas/__init__.py new file mode 100644 index 0000000..940aae0 --- /dev/null +++ b/core/lib/schemas/__init__.py @@ -0,0 +1 @@ +from core.lib.schemas import nordigen_s # noqa diff --git a/core/lib/schemas/nordigen_s.py b/core/lib/schemas/nordigen_s.py new file mode 100644 index 0000000..00ac771 --- /dev/null +++ b/core/lib/schemas/nordigen_s.py @@ -0,0 +1,77 @@ +from pydantic import BaseModel + + +class TokenNew(BaseModel): + access: str + access_expires: int + refresh: str + refresh_expires: int + + +TokenNewSchema = { + "access": "access", + "access_expires": "access_expires", + "refresh": "refresh", + "refresh_expires": "refresh_expires", +} + + +class RequisitionResult(BaseModel): + id: str + created: str + redirect: str + status: str + institution_id: str + agreement: str + reference: str + accounts: list[str] + link: str + ssn: str | None + account_selection: bool + redirect_immediate: bool + + +class Requisitions(BaseModel): + count: int + next: str | None + previous: str | None + results: list[RequisitionResult] + + +RequisitionsSchema = { + "count": "count", + "next": "next", + "previous": "previous", + "results": "results", +} + + +class RequisitionsPost(BaseModel): + id: str + created: str + redirect: str + status: str + institution_id: str + agreement: str + reference: str + accounts: list[str] + link: str + ssn: str | None + account_selection: bool + redirect_immediate: bool + + +RequisitionsPostSchema = { + "id": "id", + "created": "created", + "redirect": "redirect", + "status": "status", + "institution_id": "institution_id", + "agreement": "agreement", + "reference": "reference", + "accounts": "accounts", + "link": "link", + "ssn": "ssn", + "account_selection": "account_selection", + "redirect_immediate": "redirect_immediate", +} diff --git a/core/migrations/0004_aggregator_access_token_expires.py b/core/migrations/0004_aggregator_access_token_expires.py new file mode 100644 index 0000000..64c8a0a --- /dev/null +++ b/core/migrations/0004_aggregator_access_token_expires.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.7 on 2023-03-08 10:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0003_aggregator_enabled'), + ] + + operations = [ + migrations.AddField( + model_name='aggregator', + name='access_token_expires', + field=models.DateTimeField(blank=True, null=True), + ), + ] diff --git a/core/models.py b/core/models.py index 29e587e..a3db6de 100644 --- a/core/models.py +++ b/core/models.py @@ -42,6 +42,18 @@ class Aggregator(models.Model): secret_id = models.CharField(max_length=1024, null=True, blank=True) secret_key = models.CharField(max_length=1024, null=True, blank=True) access_token = models.CharField(max_length=1024, null=True, blank=True) + access_token_expires = models.DateTimeField(null=True, blank=True) poll_interval = models.IntegerField(default=10) enabled = models.BooleanField(default=True) + + def __str__(self): + return f"Aggregator ({self.service}) for {self.user}" + + @classmethod + def get_by_id(cls, obj_id, user): + return cls.objects.get(id=obj_id, user=user) + + @property + def client(self): + pass diff --git a/core/templates/partials/aggregator-countries.html b/core/templates/partials/aggregator-countries.html new file mode 100644 index 0000000..142594b --- /dev/null +++ b/core/templates/partials/aggregator-countries.html @@ -0,0 +1,36 @@ +{% include 'mixins/partials/notify.html' %} + + + + + + {% for item in object_list %} + + + + + {% endfor %} + +
countryactions
{{ item }} +
+ +
+
\ No newline at end of file diff --git a/core/templates/partials/aggregator-country-banks.html b/core/templates/partials/aggregator-country-banks.html new file mode 100644 index 0000000..28617e7 --- /dev/null +++ b/core/templates/partials/aggregator-country-banks.html @@ -0,0 +1,38 @@ +{% include 'mixins/partials/notify.html' %} + + + + + + + {% for item in object_list %} + + + + + + {% endfor %} + +
namelogoactions
{{ item.name }} +
+ +
+
\ No newline at end of file diff --git a/core/templates/partials/aggregator-info.html b/core/templates/partials/aggregator-info.html new file mode 100644 index 0000000..684a500 --- /dev/null +++ b/core/templates/partials/aggregator-info.html @@ -0,0 +1,75 @@ + + + + + + + + + {% for item in object_list %} + + + + + + + + {% endfor %} + +
idcreatedinstitutionaccountsactions
+ + + + + + {{ item.created }}{{ item.institution_id }}{{ item.accounts }} +
+ + {% if type == 'page' %} + + + {% else %} + + {% endif %} +
+
\ No newline at end of file diff --git a/core/templates/partials/aggregator-list.html b/core/templates/partials/aggregator-list.html index 162f9d7..83d6d24 100644 --- a/core/templates/partials/aggregator-list.html +++ b/core/templates/partials/aggregator-list.html @@ -1,8 +1,8 @@ {% load cache %} {% load cachalot cache %} -{% get_last_invalidation 'core.Hook' as last %} +{% get_last_invalidation 'core.Aggregator' as last %} {% include 'mixins/partials/notify.html' %} -{# cache 600 objects_hooks request.user.id object_list type last #} +{# cache 600 objects_aggregators request.user.id object_list type last #} - + diff --git a/core/views/aggregators.py b/core/views/aggregators.py index d514134..f13fccd 100644 --- a/core/views/aggregators.py +++ b/core/views/aggregators.py @@ -1,12 +1,12 @@ +import asyncio + from django.contrib.auth.mixins import LoginRequiredMixin -from mixins.views import ( # ObjectRead, - ObjectCreate, - ObjectDelete, - ObjectList, - ObjectUpdate, -) +from django.http import HttpResponse +from django.views import View +from mixins.views import ObjectCreate, ObjectDelete, ObjectList, ObjectUpdate from two_factor.views.mixins import OTPRequiredMixin +from core.clients.aggregators.nordigen import NordigenClient from core.forms import AggregatorForm from core.models import Aggregator from core.util import logs @@ -14,6 +14,171 @@ from core.util import logs log = logs.get_logger(__name__) +def synchronize_async_helper(to_await): + async_response = [] + + async def run_and_capture_result(): + r = await to_await + async_response.append(r) + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + coroutine = run_and_capture_result() + loop.run_until_complete(coroutine) + return async_response[0] + + +class ReqsList(LoginRequiredMixin, OTPRequiredMixin, ObjectList): + list_template = "partials/aggregator-info.html" + page_title = "Aggregator Info" + + context_object_name_singular = "account link" + context_object_name = "account links" + + list_url_name = "reqs" + list_url_args = ["type", "pk"] + + submit_url_name = "aggregator_countries" + submit_url_args = ["type", "pk"] + + def get_queryset(self, **kwargs): + pk = kwargs.get("pk") + try: + aggregator = Aggregator.get_by_id(pk, self.request.user) + + except Aggregator.DoesNotExist: + message = "Aggregator does not exist" + message_class = "danger" + context = { + "message": message, + "message_class": message_class, + "window_content": self.window_content, + } + return self.render_to_response(context) + + self.page_title = ( + f"Requisitions for {aggregator.name} ({aggregator.get_service_display()})" + ) + + run = synchronize_async_helper(NordigenClient(aggregator)) + reqs = synchronize_async_helper(run.get_requisitions()) + print("REQS", reqs) + return reqs + + +class AggregatorCountriesList(LoginRequiredMixin, OTPRequiredMixin, ObjectList): + list_template = "partials/aggregator-countries.html" + page_title = "List of countries" + + list_url_name = "aggregator_countries" + list_url_args = ["type", "pk"] + + context_object_name_singular = "country" + context_object_name = "countries" + + def get_context_data(self): + context = super().get_context_data() + context["pk"] = self.kwargs.get("pk") + return context + + def get_queryset(self, **kwargs): + pk = kwargs.get("pk") + try: + aggregator = Aggregator.get_by_id(pk, self.request.user) + + except Aggregator.DoesNotExist: + message = "Aggregator does not exist" + message_class = "danger" + context = { + "message": message, + "message_class": message_class, + "window_content": self.window_content, + } + return self.render_to_response(context) + + self.page_title = ( + f"Countries for {aggregator.name} ({aggregator.get_service_display()})" + ) + run = synchronize_async_helper(NordigenClient(aggregator)) + countries = synchronize_async_helper(run.get_countries()) + print("COUNTRIES", countries) + self.extra_args = {"pk": pk} + return countries + + +class AggregatorCountryBanksList(LoginRequiredMixin, OTPRequiredMixin, ObjectList): + list_template = "partials/aggregator-country-banks.html" + page_title = "List of banks" + + list_url_name = "aggregator_country_banks" + list_url_args = ["type", "pk", "country"] + + context_object_name_singular = "bank" + context_object_name = "banks" + + def get_context_data(self): + context = super().get_context_data() + context["pk"] = self.kwargs.get("pk") + context["country"] = self.kwargs.get("country") + return context + + def get_queryset(self, **kwargs): + pk = kwargs.get("pk") + country = kwargs.get("country") + try: + aggregator = Aggregator.get_by_id(pk, self.request.user) + + except Aggregator.DoesNotExist: + message = "Aggregator does not exist" + message_class = "danger" + context = { + "message": message, + "message_class": message_class, + "window_content": self.window_content, + } + return self.render_to_response(context) + + self.page_title = ( + f"Banks for {aggregator.name} in {country} " + f"({aggregator.get_service_display()})" + ) + run = synchronize_async_helper(NordigenClient(aggregator)) + banks = synchronize_async_helper(run.get_banks(country)) + print("BANKS", banks) + return banks + + +class AggregatorLinkBank(LoginRequiredMixin, OTPRequiredMixin, View): + def get(self, request, *args, **kwargs): + pk = kwargs.get("pk") + bank = kwargs.get("bank") + try: + aggregator = Aggregator.get_by_id(pk, self.request.user) + + except Aggregator.DoesNotExist: + message = "Aggregator does not exist" + message_class = "danger" + context = { + "message": message, + "message_class": message_class, + "window_content": self.window_content, + } + return self.render_to_response(context) + run = synchronize_async_helper(NordigenClient(aggregator)) + auth_url = synchronize_async_helper(run.build_link(bank)) + + # Redirect to auth url + print("AUTH URL", auth_url) + # Create a blank response + response = HttpResponse() + response["HX-Redirect"] = auth_url + # return redirect(auth_url) + return response + + class AggregatorList(LoginRequiredMixin, OTPRequiredMixin, ObjectList): list_template = "partials/aggregator-list.html" model = Aggregator diff --git a/core/views/base.py b/core/views/base.py index c0e102b..8fd01a6 100644 --- a/core/views/base.py +++ b/core/views/base.py @@ -3,8 +3,8 @@ import logging # import stripe from django.conf import settings from django.contrib.auth.mixins import LoginRequiredMixin -from django.shortcuts import redirect, render -from django.urls import reverse, reverse_lazy +from django.shortcuts import render +from django.urls import reverse_lazy from django.views import View from django.views.generic.edit import CreateView diff --git a/requirements.txt b/requirements.txt index e8dfae4..1896360 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,4 +33,7 @@ forex_python pyOpenSSL Klein ConfigObject - +aiohttp[speedups] +aioredis[hiredis] +elasticsearch[async] +uvloop
{{ item.user }}{{ item.name }}{{ item.name }} {{ item.get_service_display }} {% if item.enabled %} @@ -72,31 +72,6 @@ - {% if type == 'page' %} - - - {% else %} - - {% endif %}