From 9d37e2bfb8584803069d271b6543574ab575a37e Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Fri, 24 Feb 2023 07:20:51 +0000 Subject: [PATCH] Integrate Lago with Stripe --- app/local_settings.py | 1 + app/urls.py | 5 +- core/admin.py | 2 +- core/lib/billing.py | 93 +++++++++++++++++-- core/lib/customers.py | 65 ------------- ...nagementpolicy_assetgroup_hook_and_more.py | 5 +- core/migrations/0003_user_customer_id.py | 20 ++++ core/migrations/0004_user_stripe_id.py | 18 ++++ core/models.py | 34 +++++-- core/templates/base.html | 2 +- core/templates/billing.html | 14 --- core/tests/trading/test_live.py | 2 - core/views/base.py | 38 ++++---- 13 files changed, 174 insertions(+), 125 deletions(-) delete mode 100644 core/lib/customers.py create mode 100644 core/migrations/0003_user_customer_id.py create mode 100644 core/migrations/0004_user_stripe_id.py diff --git a/app/local_settings.py b/app/local_settings.py index b59ab70..7c12d13 100644 --- a/app/local_settings.py +++ b/app/local_settings.py @@ -14,6 +14,7 @@ CSRF_TRUSTED_ORIGINS = getenv("CSRF_TRUSTED_ORIGINS", URL).split(",") # Stripe BILLING_ENABLED = getenv("BILLING_ENABLED", "false").lower() in trues + STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "") STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "") diff --git a/app/urls.py b/app/urls.py index 0198f96..5a9c36e 100644 --- a/app/urls.py +++ b/app/urls.py @@ -18,7 +18,6 @@ from django.conf.urls.static import static from django.contrib import admin from django.contrib.auth.views import LogoutView from django.urls import include, path -from django.views.generic import TemplateView from two_factor.urls import urlpatterns as tf_urls from core.views import ( @@ -45,7 +44,7 @@ urlpatterns = [ path("__debug__/", include("debug_toolbar.urls")), path("", base.Home.as_view(), name="home"), # path("callback", Callback.as_view(), name="callback"), - # path("billing/", base.Billing.as_view(), name="billing"), + path("billing/", base.Billing.as_view(), name="billing"), # path("order//", base.Order.as_view(), name="order"), # path( # "cancel_subscription//", @@ -56,7 +55,7 @@ urlpatterns = [ # "success/", TemplateView.as_view(template_name="success.html"), name="success" # ), # path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"), - # path("portal", base.Portal.as_view(), name="portal"), + path("portal", base.Portal.as_view(), name="portal"), path("sapp/", admin.site.urls), # 2FA login urls path("", include(tf_urls)), diff --git a/core/admin.py b/core/admin.py index 60f5bf0..c43ca4e 100644 --- a/core/admin.py +++ b/core/admin.py @@ -32,7 +32,7 @@ class CustomUserAdmin(UserAdmin): *UserAdmin.fieldsets, ( "Billing information", - {"fields": ("billing_provider_id",)}, + {"fields": ("billing_provider_id", "customer_id", "stripe_id")}, ), # ( # "Payment information", diff --git a/core/lib/billing.py b/core/lib/billing.py index f4a0dbe..8d56e3e 100644 --- a/core/lib/billing.py +++ b/core/lib/billing.py @@ -1,11 +1,90 @@ +import stripe from django.conf import settings from lago_python_client import Client -from lago_python_client.models import ( - Charge, - Charges, - Customer, - CustomerBillingConfiguration, - Plan, -) +from lago_python_client.clients.base_client import LagoApiError +from lago_python_client.models import Customer, CustomerBillingConfiguration client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL) + + +def expand_name(first_name, last_name): + """ + Convert two name variables into one. + Last name without a first name is ignored. + """ + name = None + if first_name: + name = first_name + # We only want to put the last name if we have a first name + if last_name: + name += f" {last_name}" + return name + + +def get_or_create(email, first_name, last_name): + """ + Get a customer ID from Stripe if one with the given email exists. + Create a customer if one does not. + Raise an exception if two or more customers matching the given email exist. + """ + # Let's see if we're just missing the ID + matching_customers = stripe.Customer.list(email=email, limit=2) + if len(matching_customers) == 2: + # Something is horribly wrong + raise Exception(f"Two customers found for email {email}") + + elif len(matching_customers) == 1: + # We found a customer. Let's copy the ID + customer = matching_customers["data"][0] + customer_id = customer["id"] + return customer_id + + else: + # We didn't find anything. Create the customer + + # Create a name, since we have 2 variables which could be null + name = expand_name(first_name, last_name) + cast = {"email": email} + if name: + cast["name"] = name + customer = stripe.Customer.create(**cast) + + return customer.id + + +def create_or_update_customer(user): + try: + customer = client.customers().find(str(user.customer_id)) + except LagoApiError: + customer = None + if not customer: + customer = Customer( + external_id=str(user.customer_id), + name=f"{user.first_name} {user.last_name}", + ) + + customer.external_id = str(user.customer_id) + customer.email = user.email + customer.name = f"{user.first_name} {user.last_name}" + customer.billing_configuration = CustomerBillingConfiguration( + payment_provider="stripe", + provider_customer_id=str(user.stripe_id), + ) + + try: + created = client.customers().create(customer) + except LagoApiError as e: + print(e.response) + + lago_id = created.lago_id + + return lago_id + + +def update_customer_fields(user): + """ + Update the customer fields in Stripe. + """ + stripe.Customer.modify(user.stripe_id, email=user.email) + name = expand_name(user.first_name, user.last_name) + stripe.Customer.modify(user.stripe_id, name=name) diff --git a/core/lib/customers.py b/core/lib/customers.py deleted file mode 100644 index 9e8dfee..0000000 --- a/core/lib/customers.py +++ /dev/null @@ -1,65 +0,0 @@ -# import logging - -# import stripe - -# logger = logging.getLogger(__name__) - - -# def expand_name(first_name, last_name): -# """ -# Convert two name variables into one. -# Last name without a first name is ignored. -# """ -# name = None -# if first_name: -# name = first_name -# # We only want to put the last name if we have a first name -# if last_name: -# name += f" {last_name}" -# return name - - -# def get_or_create(email, first_name, last_name): -# """ -# Get a customer ID from Stripe if one with the given email exists. -# Create a customer if one does not. -# Raise an exception if two or more customers matching the given email exist. -# """ -# # Let's see if we're just missing the ID -# matching_customers = stripe.Customer.list(email=email, limit=2) -# if len(matching_customers) == 2: -# # Something is horribly wrong -# logger.error(f"Two customers found for email {email}") -# raise Exception(f"Two customers found for email {email}") - -# elif len(matching_customers) == 1: -# # We found a customer. Let's copy the ID -# customer = matching_customers["data"][0] -# customer_id = customer["id"] -# return customer_id - -# else: -# # We didn't find anything. Create the customer - -# # Create a name, since we have 2 variables which could be null -# name = expand_name(first_name, last_name) -# cast = {"email": email} -# if name: -# cast["name"] = name -# customer = stripe.Customer.create(**cast) -# logger.info(f"Created new Stripe customer {customer.id} with email {email}") - -# return customer.id - - -# def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None): -# """ -# Update the customer fields in Stripe. -# """ -# if email: -# stripe.Customer.modify(stripe_id, email=email) -# logger.info(f"Modified Stripe customer {stripe_id} to have email {email}") -# if first_name or last_name: -# name = expand_name(first_name, last_name) -# stripe.Customer.modify(stripe_id, name=name) -# logger.info(f"Modified Stripe customer {stripe_id} to have email {name}") diff --git a/core/migrations/0002_account_activemanagementpolicy_assetgroup_hook_and_more.py b/core/migrations/0002_account_activemanagementpolicy_assetgroup_hook_and_more.py index d04ec72..257fdac 100644 --- a/core/migrations/0002_account_activemanagementpolicy_assetgroup_hook_and_more.py +++ b/core/migrations/0002_account_activemanagementpolicy_assetgroup_hook_and_more.py @@ -1,9 +1,10 @@ # Generated by Django 4.1.7 on 2023-02-24 13:18 +import uuid + +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion -import uuid class Migration(migrations.Migration): diff --git a/core/migrations/0003_user_customer_id.py b/core/migrations/0003_user_customer_id.py new file mode 100644 index 0000000..70b3a70 --- /dev/null +++ b/core/migrations/0003_user_customer_id.py @@ -0,0 +1,20 @@ +# Generated by Django 4.1.7 on 2023-02-24 13:21 + +import uuid + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0002_account_activemanagementpolicy_assetgroup_hook_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='user', + name='customer_id', + field=models.UUIDField(blank=True, default=uuid.uuid4, null=True), + ), + ] diff --git a/core/migrations/0004_user_stripe_id.py b/core/migrations/0004_user_stripe_id.py new file mode 100644 index 0000000..1204719 --- /dev/null +++ b/core/migrations/0004_user_stripe_id.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.7 on 2023-02-24 16:09 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0003_user_customer_id'), + ] + + operations = [ + migrations.AddField( + model_name='user', + name='stripe_id', + field=models.CharField(blank=True, max_length=255, null=True), + ), + ] diff --git a/core/models.py b/core/models.py index 17fdd6c..891aea9 100644 --- a/core/models.py +++ b/core/models.py @@ -1,8 +1,8 @@ import uuid from datetime import timedelta -# import stripe -# from django.conf import settings +import stripe +from django.conf import settings from django.contrib.auth.models import AbstractUser from django.db import models @@ -11,6 +11,7 @@ from core.exchanges.fake import FakeExchange from core.exchanges.oanda import OANDAExchange # from core.lib.customers import get_or_create, update_customer_fields +from core.lib import billing from core.util import logs log = logs.get_logger(__name__) @@ -93,21 +94,36 @@ ADJUST_CLOSE_NOTIFY_CHOICES = ( class User(AbstractUser): # Stripe customer ID - # unique_id = models.UUIDField( - # default=uuid.uuid4, - # ) - # stripe_id = models.CharField(max_length=255, null=True, blank=True) + stripe_id = models.CharField(max_length=255, null=True, blank=True) + customer_id = models.UUIDField(default=uuid.uuid4, null=True, blank=True) billing_provider_id = models.CharField(max_length=255, null=True, blank=True) # last_payment = models.DateTimeField(null=True, blank=True) # plans = models.ManyToManyField(Plan, blank=True) email = models.EmailField(unique=True) - # def __init__(self, *args, **kwargs): - # super().__init__(*args, **kwargs) - # self._original = self + def delete(self, *args, **kwargs): + if settings.BILLING_ENABLED: + if self.stripe_id: + stripe.Customer.delete(self.stripe_id) + log.info(f"Deleted Stripe customer {self.stripe_id}") + super().delete(*args, **kwargs) # Override save to update attributes in Lago def save(self, *args, **kwargs): + if self.customer_id is None: + self.customer_id = uuid.uuid4() + + if settings.BILLING_ENABLED: + if not self.stripe_id: # stripe ID not stored + self.stripe_id = billing.get_or_create( + self.email, self.first_name, self.last_name + ) + + billing_id = billing.create_or_update_customer(self) + self.billing_provider_id = billing_id + + billing.update_customer_fields(self) + super().save(*args, **kwargs) def get_notification_settings(self): diff --git a/core/templates/base.html b/core/templates/base.html index 3a2bbbf..de9009a 100644 --- a/core/templates/base.html +++ b/core/templates/base.html @@ -288,7 +288,7 @@ {% endif %} {% if settings.BILLING_ENABLED %} {% if user.is_authenticated %} - + Billing {% endif %} diff --git a/core/templates/billing.html b/core/templates/billing.html index abd3875..59c91b4 100644 --- a/core/templates/billing.html +++ b/core/templates/billing.html @@ -10,20 +10,6 @@ {{ user.first_name }} {{ user.last_name }} - - - - - {% for plan in user.plans.all %} - {{ plan.name }} - {% endfor %} - - - - - - {{ user.last_payment }} - diff --git a/core/tests/trading/test_live.py b/core/tests/trading/test_live.py index 682ef84..ff95047 100644 --- a/core/tests/trading/test_live.py +++ b/core/tests/trading/test_live.py @@ -413,8 +413,6 @@ class ActiveManagementLiveTestCase(ElasticMock, StrategyMixin, LiveBase, TestCas }, ] } - print("ACTIONS", self.ams.actions) - print("EXP", expected) self.assertEqual(self.ams.actions, expected) diff --git a/core/views/base.py b/core/views/base.py index 0405358..2ba8e39 100644 --- a/core/views/base.py +++ b/core/views/base.py @@ -1,10 +1,8 @@ import logging import stripe -from asgiref.sync import sync_to_async from django.conf import settings from django.contrib.auth.mixins import LoginRequiredMixin -from django.http import JsonResponse from django.shortcuts import redirect, render from django.urls import reverse, reverse_lazy from django.views import View @@ -14,7 +12,7 @@ from core.forms import NewUserForm # from core.lib.products import assemble_plan_map # from core.models import Plan, Session -from core.lib import billing +# from core.lib import billing from core.lib.notify import raw_sendmsg logger = logging.getLogger(__name__) @@ -29,16 +27,13 @@ class Home(LoginRequiredMixin, View): return render(request, self.template_name) -# class Billing(LoginRequiredMixin, View): -# template_name = "billing.html" +class Billing(LoginRequiredMixin, View): + template_name = "billing.html" -# async def get(self, request): -# if not settings.STRIPE_ENABLED: -# return redirect(reverse("home")) -# plans = await sync_to_async(list)(Plan.objects.all()) -# user_plans = await sync_to_async(list)(request.user.plans.all()) -# context = {"plans": plans, "user_plans": user_plans} -# return render(request, self.template_name, context) + def get(self, request): + if not settings.BILLING_ENABLED: + return redirect(reverse("home")) + return render(request, self.template_name) # class Order(LoginRequiredMixin, View): @@ -110,12 +105,13 @@ class Signup(CreateView): return super().get(request, *args, **kwargs) -# class Portal(LoginRequiredMixin, View): -# async def get(self, request): -# if not settings.STRIPE_ENABLED: -# return redirect(reverse("home")) -# session = stripe.billing_portal.Session.create( -# customer=request.user.stripe_id, -# return_url=request.build_absolute_uri(reverse("billing")), -# ) -# return redirect(session.url) +class Portal(LoginRequiredMixin, View): + def get(self, request): + if not settings.BILLING_ENABLED: + return redirect(reverse("home")) + + session = stripe.billing_portal.Session.create( + customer=request.user.stripe_id, + return_url=request.build_absolute_uri(reverse("billing")), + ) + return redirect(session.url)