Integrate Lago with Stripe

master
Mark Veidemanis 1 year ago
parent cde1392e68
commit 9d37e2bfb8
Signed by: m
GPG Key ID: 5ACFCEED46C0904F

@ -14,6 +14,7 @@ CSRF_TRUSTED_ORIGINS = getenv("CSRF_TRUSTED_ORIGINS", URL).split(",")
# Stripe # Stripe
BILLING_ENABLED = getenv("BILLING_ENABLED", "false").lower() in trues BILLING_ENABLED = getenv("BILLING_ENABLED", "false").lower() in trues
STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues
STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "") STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "")
STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "") STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "")

@ -18,7 +18,6 @@ from django.conf.urls.static import static
from django.contrib import admin from django.contrib import admin
from django.contrib.auth.views import LogoutView from django.contrib.auth.views import LogoutView
from django.urls import include, path from django.urls import include, path
from django.views.generic import TemplateView
from two_factor.urls import urlpatterns as tf_urls from two_factor.urls import urlpatterns as tf_urls
from core.views import ( from core.views import (
@ -45,7 +44,7 @@ urlpatterns = [
path("__debug__/", include("debug_toolbar.urls")), path("__debug__/", include("debug_toolbar.urls")),
path("", base.Home.as_view(), name="home"), path("", base.Home.as_view(), name="home"),
# path("callback", Callback.as_view(), name="callback"), # path("callback", Callback.as_view(), name="callback"),
# path("billing/", base.Billing.as_view(), name="billing"), path("billing/", base.Billing.as_view(), name="billing"),
# path("order/<str:plan_name>/", base.Order.as_view(), name="order"), # path("order/<str:plan_name>/", base.Order.as_view(), name="order"),
# path( # path(
# "cancel_subscription/<str:plan_name>/", # "cancel_subscription/<str:plan_name>/",
@ -56,7 +55,7 @@ urlpatterns = [
# "success/", TemplateView.as_view(template_name="success.html"), name="success" # "success/", TemplateView.as_view(template_name="success.html"), name="success"
# ), # ),
# path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"), # path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"),
# path("portal", base.Portal.as_view(), name="portal"), path("portal", base.Portal.as_view(), name="portal"),
path("sapp/", admin.site.urls), path("sapp/", admin.site.urls),
# 2FA login urls # 2FA login urls
path("", include(tf_urls)), path("", include(tf_urls)),

@ -32,7 +32,7 @@ class CustomUserAdmin(UserAdmin):
*UserAdmin.fieldsets, *UserAdmin.fieldsets,
( (
"Billing information", "Billing information",
{"fields": ("billing_provider_id",)}, {"fields": ("billing_provider_id", "customer_id", "stripe_id")},
), ),
# ( # (
# "Payment information", # "Payment information",

@ -1,11 +1,90 @@
import stripe
from django.conf import settings from django.conf import settings
from lago_python_client import Client from lago_python_client import Client
from lago_python_client.models import ( from lago_python_client.clients.base_client import LagoApiError
Charge, from lago_python_client.models import Customer, CustomerBillingConfiguration
Charges,
Customer,
CustomerBillingConfiguration,
Plan,
)
client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL) client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL)
def expand_name(first_name, last_name):
"""
Convert two name variables into one.
Last name without a first name is ignored.
"""
name = None
if first_name:
name = first_name
# We only want to put the last name if we have a first name
if last_name:
name += f" {last_name}"
return name
def get_or_create(email, first_name, last_name):
"""
Get a customer ID from Stripe if one with the given email exists.
Create a customer if one does not.
Raise an exception if two or more customers matching the given email exist.
"""
# Let's see if we're just missing the ID
matching_customers = stripe.Customer.list(email=email, limit=2)
if len(matching_customers) == 2:
# Something is horribly wrong
raise Exception(f"Two customers found for email {email}")
elif len(matching_customers) == 1:
# We found a customer. Let's copy the ID
customer = matching_customers["data"][0]
customer_id = customer["id"]
return customer_id
else:
# We didn't find anything. Create the customer
# Create a name, since we have 2 variables which could be null
name = expand_name(first_name, last_name)
cast = {"email": email}
if name:
cast["name"] = name
customer = stripe.Customer.create(**cast)
return customer.id
def create_or_update_customer(user):
try:
customer = client.customers().find(str(user.customer_id))
except LagoApiError:
customer = None
if not customer:
customer = Customer(
external_id=str(user.customer_id),
name=f"{user.first_name} {user.last_name}",
)
customer.external_id = str(user.customer_id)
customer.email = user.email
customer.name = f"{user.first_name} {user.last_name}"
customer.billing_configuration = CustomerBillingConfiguration(
payment_provider="stripe",
provider_customer_id=str(user.stripe_id),
)
try:
created = client.customers().create(customer)
except LagoApiError as e:
print(e.response)
lago_id = created.lago_id
return lago_id
def update_customer_fields(user):
"""
Update the customer fields in Stripe.
"""
stripe.Customer.modify(user.stripe_id, email=user.email)
name = expand_name(user.first_name, user.last_name)
stripe.Customer.modify(user.stripe_id, name=name)

@ -1,65 +0,0 @@
# import logging
# import stripe
# logger = logging.getLogger(__name__)
# def expand_name(first_name, last_name):
# """
# Convert two name variables into one.
# Last name without a first name is ignored.
# """
# name = None
# if first_name:
# name = first_name
# # We only want to put the last name if we have a first name
# if last_name:
# name += f" {last_name}"
# return name
# def get_or_create(email, first_name, last_name):
# """
# Get a customer ID from Stripe if one with the given email exists.
# Create a customer if one does not.
# Raise an exception if two or more customers matching the given email exist.
# """
# # Let's see if we're just missing the ID
# matching_customers = stripe.Customer.list(email=email, limit=2)
# if len(matching_customers) == 2:
# # Something is horribly wrong
# logger.error(f"Two customers found for email {email}")
# raise Exception(f"Two customers found for email {email}")
# elif len(matching_customers) == 1:
# # We found a customer. Let's copy the ID
# customer = matching_customers["data"][0]
# customer_id = customer["id"]
# return customer_id
# else:
# # We didn't find anything. Create the customer
# # Create a name, since we have 2 variables which could be null
# name = expand_name(first_name, last_name)
# cast = {"email": email}
# if name:
# cast["name"] = name
# customer = stripe.Customer.create(**cast)
# logger.info(f"Created new Stripe customer {customer.id} with email {email}")
# return customer.id
# def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None):
# """
# Update the customer fields in Stripe.
# """
# if email:
# stripe.Customer.modify(stripe_id, email=email)
# logger.info(f"Modified Stripe customer {stripe_id} to have email {email}")
# if first_name or last_name:
# name = expand_name(first_name, last_name)
# stripe.Customer.modify(stripe_id, name=name)
# logger.info(f"Modified Stripe customer {stripe_id} to have email {name}")

@ -1,9 +1,10 @@
# Generated by Django 4.1.7 on 2023-02-24 13:18 # Generated by Django 4.1.7 on 2023-02-24 13:18
import uuid
import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration): class Migration(migrations.Migration):

@ -0,0 +1,20 @@
# Generated by Django 4.1.7 on 2023-02-24 13:21
import uuid
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0002_account_activemanagementpolicy_assetgroup_hook_and_more'),
]
operations = [
migrations.AddField(
model_name='user',
name='customer_id',
field=models.UUIDField(blank=True, default=uuid.uuid4, null=True),
),
]

@ -0,0 +1,18 @@
# Generated by Django 4.1.7 on 2023-02-24 16:09
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0003_user_customer_id'),
]
operations = [
migrations.AddField(
model_name='user',
name='stripe_id',
field=models.CharField(blank=True, max_length=255, null=True),
),
]

@ -1,8 +1,8 @@
import uuid import uuid
from datetime import timedelta from datetime import timedelta
# import stripe import stripe
# from django.conf import settings from django.conf import settings
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from django.db import models from django.db import models
@ -11,6 +11,7 @@ from core.exchanges.fake import FakeExchange
from core.exchanges.oanda import OANDAExchange from core.exchanges.oanda import OANDAExchange
# from core.lib.customers import get_or_create, update_customer_fields # from core.lib.customers import get_or_create, update_customer_fields
from core.lib import billing
from core.util import logs from core.util import logs
log = logs.get_logger(__name__) log = logs.get_logger(__name__)
@ -93,21 +94,36 @@ ADJUST_CLOSE_NOTIFY_CHOICES = (
class User(AbstractUser): class User(AbstractUser):
# Stripe customer ID # Stripe customer ID
# unique_id = models.UUIDField( stripe_id = models.CharField(max_length=255, null=True, blank=True)
# default=uuid.uuid4, customer_id = models.UUIDField(default=uuid.uuid4, null=True, blank=True)
# )
# stripe_id = models.CharField(max_length=255, null=True, blank=True)
billing_provider_id = models.CharField(max_length=255, null=True, blank=True) billing_provider_id = models.CharField(max_length=255, null=True, blank=True)
# last_payment = models.DateTimeField(null=True, blank=True) # last_payment = models.DateTimeField(null=True, blank=True)
# plans = models.ManyToManyField(Plan, blank=True) # plans = models.ManyToManyField(Plan, blank=True)
email = models.EmailField(unique=True) email = models.EmailField(unique=True)
# def __init__(self, *args, **kwargs): def delete(self, *args, **kwargs):
# super().__init__(*args, **kwargs) if settings.BILLING_ENABLED:
# self._original = self if self.stripe_id:
stripe.Customer.delete(self.stripe_id)
log.info(f"Deleted Stripe customer {self.stripe_id}")
super().delete(*args, **kwargs)
# Override save to update attributes in Lago # Override save to update attributes in Lago
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if self.customer_id is None:
self.customer_id = uuid.uuid4()
if settings.BILLING_ENABLED:
if not self.stripe_id: # stripe ID not stored
self.stripe_id = billing.get_or_create(
self.email, self.first_name, self.last_name
)
billing_id = billing.create_or_update_customer(self)
self.billing_provider_id = billing_id
billing.update_customer_fields(self)
super().save(*args, **kwargs) super().save(*args, **kwargs)
def get_notification_settings(self): def get_notification_settings(self):

@ -288,7 +288,7 @@
{% endif %} {% endif %}
{% if settings.BILLING_ENABLED %} {% if settings.BILLING_ENABLED %}
{% if user.is_authenticated %} {% if user.is_authenticated %}
<a class="navbar-item" href="{# url 'billing' #}"> <a class="navbar-item" href="{% url 'billing' %}">
Billing Billing
</a> </a>
{% endif %} {% endif %}

@ -10,20 +10,6 @@
</span> </span>
<span class="tag">{{ user.first_name }} {{ user.last_name }}</span> <span class="tag">{{ user.first_name }} {{ user.last_name }}</span>
</a> </a>
<a class="panel-block">
<span class="panel-icon">
<i class="fas fa-binary" aria-hidden="true"></i>
</span>
{% for plan in user.plans.all %}
<span class="tag">{{ plan.name }}</span>
{% endfor %}
</a>
<a class="panel-block">
<span class="panel-icon">
<i class="fas fa-credit-card" aria-hidden="true"></i>
</span>
<span class="tag">{{ user.last_payment }}</span>
</a>
<a class="panel-block" href="{% url 'portal' %}"> <a class="panel-block" href="{% url 'portal' %}">
<span class="panel-icon"> <span class="panel-icon">
<i class="fa-brands fa-stripe-s" aria-hidden="true"></i> <i class="fa-brands fa-stripe-s" aria-hidden="true"></i>

@ -413,8 +413,6 @@ class ActiveManagementLiveTestCase(ElasticMock, StrategyMixin, LiveBase, TestCas
}, },
] ]
} }
print("ACTIONS", self.ams.actions)
print("EXP", expected)
self.assertEqual(self.ams.actions, expected) self.assertEqual(self.ams.actions, expected)

@ -1,10 +1,8 @@
import logging import logging
import stripe import stripe
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import JsonResponse
from django.shortcuts import redirect, render from django.shortcuts import redirect, render
from django.urls import reverse, reverse_lazy from django.urls import reverse, reverse_lazy
from django.views import View from django.views import View
@ -14,7 +12,7 @@ from core.forms import NewUserForm
# from core.lib.products import assemble_plan_map # from core.lib.products import assemble_plan_map
# from core.models import Plan, Session # from core.models import Plan, Session
from core.lib import billing # from core.lib import billing
from core.lib.notify import raw_sendmsg from core.lib.notify import raw_sendmsg
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,16 +27,13 @@ class Home(LoginRequiredMixin, View):
return render(request, self.template_name) return render(request, self.template_name)
# class Billing(LoginRequiredMixin, View): class Billing(LoginRequiredMixin, View):
# template_name = "billing.html" template_name = "billing.html"
# async def get(self, request): def get(self, request):
# if not settings.STRIPE_ENABLED: if not settings.BILLING_ENABLED:
# return redirect(reverse("home")) return redirect(reverse("home"))
# plans = await sync_to_async(list)(Plan.objects.all()) return render(request, self.template_name)
# user_plans = await sync_to_async(list)(request.user.plans.all())
# context = {"plans": plans, "user_plans": user_plans}
# return render(request, self.template_name, context)
# class Order(LoginRequiredMixin, View): # class Order(LoginRequiredMixin, View):
@ -110,12 +105,13 @@ class Signup(CreateView):
return super().get(request, *args, **kwargs) return super().get(request, *args, **kwargs)
# class Portal(LoginRequiredMixin, View): class Portal(LoginRequiredMixin, View):
# async def get(self, request): def get(self, request):
# if not settings.STRIPE_ENABLED: if not settings.BILLING_ENABLED:
# return redirect(reverse("home")) return redirect(reverse("home"))
# session = stripe.billing_portal.Session.create(
# customer=request.user.stripe_id, session = stripe.billing_portal.Session.create(
# return_url=request.build_absolute_uri(reverse("billing")), customer=request.user.stripe_id,
# ) return_url=request.build_absolute_uri(reverse("billing")),
# return redirect(session.url) )
return redirect(session.url)

Loading…
Cancel
Save