Integrate Lago with Stripe
This commit is contained in:
parent
cde1392e68
commit
9d37e2bfb8
|
@ -14,6 +14,7 @@ CSRF_TRUSTED_ORIGINS = getenv("CSRF_TRUSTED_ORIGINS", URL).split(",")
|
||||||
|
|
||||||
# Stripe
|
# Stripe
|
||||||
BILLING_ENABLED = getenv("BILLING_ENABLED", "false").lower() in trues
|
BILLING_ENABLED = getenv("BILLING_ENABLED", "false").lower() in trues
|
||||||
|
|
||||||
STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues
|
STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues
|
||||||
STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "")
|
STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "")
|
||||||
STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "")
|
STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "")
|
||||||
|
|
|
@ -18,7 +18,6 @@ from django.conf.urls.static import static
|
||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
from django.contrib.auth.views import LogoutView
|
from django.contrib.auth.views import LogoutView
|
||||||
from django.urls import include, path
|
from django.urls import include, path
|
||||||
from django.views.generic import TemplateView
|
|
||||||
from two_factor.urls import urlpatterns as tf_urls
|
from two_factor.urls import urlpatterns as tf_urls
|
||||||
|
|
||||||
from core.views import (
|
from core.views import (
|
||||||
|
@ -45,7 +44,7 @@ urlpatterns = [
|
||||||
path("__debug__/", include("debug_toolbar.urls")),
|
path("__debug__/", include("debug_toolbar.urls")),
|
||||||
path("", base.Home.as_view(), name="home"),
|
path("", base.Home.as_view(), name="home"),
|
||||||
# path("callback", Callback.as_view(), name="callback"),
|
# path("callback", Callback.as_view(), name="callback"),
|
||||||
# path("billing/", base.Billing.as_view(), name="billing"),
|
path("billing/", base.Billing.as_view(), name="billing"),
|
||||||
# path("order/<str:plan_name>/", base.Order.as_view(), name="order"),
|
# path("order/<str:plan_name>/", base.Order.as_view(), name="order"),
|
||||||
# path(
|
# path(
|
||||||
# "cancel_subscription/<str:plan_name>/",
|
# "cancel_subscription/<str:plan_name>/",
|
||||||
|
@ -56,7 +55,7 @@ urlpatterns = [
|
||||||
# "success/", TemplateView.as_view(template_name="success.html"), name="success"
|
# "success/", TemplateView.as_view(template_name="success.html"), name="success"
|
||||||
# ),
|
# ),
|
||||||
# path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"),
|
# path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"),
|
||||||
# path("portal", base.Portal.as_view(), name="portal"),
|
path("portal", base.Portal.as_view(), name="portal"),
|
||||||
path("sapp/", admin.site.urls),
|
path("sapp/", admin.site.urls),
|
||||||
# 2FA login urls
|
# 2FA login urls
|
||||||
path("", include(tf_urls)),
|
path("", include(tf_urls)),
|
||||||
|
|
|
@ -32,7 +32,7 @@ class CustomUserAdmin(UserAdmin):
|
||||||
*UserAdmin.fieldsets,
|
*UserAdmin.fieldsets,
|
||||||
(
|
(
|
||||||
"Billing information",
|
"Billing information",
|
||||||
{"fields": ("billing_provider_id",)},
|
{"fields": ("billing_provider_id", "customer_id", "stripe_id")},
|
||||||
),
|
),
|
||||||
# (
|
# (
|
||||||
# "Payment information",
|
# "Payment information",
|
||||||
|
|
|
@ -1,11 +1,90 @@
|
||||||
|
import stripe
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from lago_python_client import Client
|
from lago_python_client import Client
|
||||||
from lago_python_client.models import (
|
from lago_python_client.clients.base_client import LagoApiError
|
||||||
Charge,
|
from lago_python_client.models import Customer, CustomerBillingConfiguration
|
||||||
Charges,
|
|
||||||
Customer,
|
|
||||||
CustomerBillingConfiguration,
|
|
||||||
Plan,
|
|
||||||
)
|
|
||||||
|
|
||||||
client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL)
|
client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_name(first_name, last_name):
|
||||||
|
"""
|
||||||
|
Convert two name variables into one.
|
||||||
|
Last name without a first name is ignored.
|
||||||
|
"""
|
||||||
|
name = None
|
||||||
|
if first_name:
|
||||||
|
name = first_name
|
||||||
|
# We only want to put the last name if we have a first name
|
||||||
|
if last_name:
|
||||||
|
name += f" {last_name}"
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create(email, first_name, last_name):
|
||||||
|
"""
|
||||||
|
Get a customer ID from Stripe if one with the given email exists.
|
||||||
|
Create a customer if one does not.
|
||||||
|
Raise an exception if two or more customers matching the given email exist.
|
||||||
|
"""
|
||||||
|
# Let's see if we're just missing the ID
|
||||||
|
matching_customers = stripe.Customer.list(email=email, limit=2)
|
||||||
|
if len(matching_customers) == 2:
|
||||||
|
# Something is horribly wrong
|
||||||
|
raise Exception(f"Two customers found for email {email}")
|
||||||
|
|
||||||
|
elif len(matching_customers) == 1:
|
||||||
|
# We found a customer. Let's copy the ID
|
||||||
|
customer = matching_customers["data"][0]
|
||||||
|
customer_id = customer["id"]
|
||||||
|
return customer_id
|
||||||
|
|
||||||
|
else:
|
||||||
|
# We didn't find anything. Create the customer
|
||||||
|
|
||||||
|
# Create a name, since we have 2 variables which could be null
|
||||||
|
name = expand_name(first_name, last_name)
|
||||||
|
cast = {"email": email}
|
||||||
|
if name:
|
||||||
|
cast["name"] = name
|
||||||
|
customer = stripe.Customer.create(**cast)
|
||||||
|
|
||||||
|
return customer.id
|
||||||
|
|
||||||
|
|
||||||
|
def create_or_update_customer(user):
|
||||||
|
try:
|
||||||
|
customer = client.customers().find(str(user.customer_id))
|
||||||
|
except LagoApiError:
|
||||||
|
customer = None
|
||||||
|
if not customer:
|
||||||
|
customer = Customer(
|
||||||
|
external_id=str(user.customer_id),
|
||||||
|
name=f"{user.first_name} {user.last_name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
customer.external_id = str(user.customer_id)
|
||||||
|
customer.email = user.email
|
||||||
|
customer.name = f"{user.first_name} {user.last_name}"
|
||||||
|
customer.billing_configuration = CustomerBillingConfiguration(
|
||||||
|
payment_provider="stripe",
|
||||||
|
provider_customer_id=str(user.stripe_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
created = client.customers().create(customer)
|
||||||
|
except LagoApiError as e:
|
||||||
|
print(e.response)
|
||||||
|
|
||||||
|
lago_id = created.lago_id
|
||||||
|
|
||||||
|
return lago_id
|
||||||
|
|
||||||
|
|
||||||
|
def update_customer_fields(user):
|
||||||
|
"""
|
||||||
|
Update the customer fields in Stripe.
|
||||||
|
"""
|
||||||
|
stripe.Customer.modify(user.stripe_id, email=user.email)
|
||||||
|
name = expand_name(user.first_name, user.last_name)
|
||||||
|
stripe.Customer.modify(user.stripe_id, name=name)
|
||||||
|
|
|
@ -1,65 +0,0 @@
|
||||||
# import logging
|
|
||||||
|
|
||||||
# import stripe
|
|
||||||
|
|
||||||
# logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# def expand_name(first_name, last_name):
|
|
||||||
# """
|
|
||||||
# Convert two name variables into one.
|
|
||||||
# Last name without a first name is ignored.
|
|
||||||
# """
|
|
||||||
# name = None
|
|
||||||
# if first_name:
|
|
||||||
# name = first_name
|
|
||||||
# # We only want to put the last name if we have a first name
|
|
||||||
# if last_name:
|
|
||||||
# name += f" {last_name}"
|
|
||||||
# return name
|
|
||||||
|
|
||||||
|
|
||||||
# def get_or_create(email, first_name, last_name):
|
|
||||||
# """
|
|
||||||
# Get a customer ID from Stripe if one with the given email exists.
|
|
||||||
# Create a customer if one does not.
|
|
||||||
# Raise an exception if two or more customers matching the given email exist.
|
|
||||||
# """
|
|
||||||
# # Let's see if we're just missing the ID
|
|
||||||
# matching_customers = stripe.Customer.list(email=email, limit=2)
|
|
||||||
# if len(matching_customers) == 2:
|
|
||||||
# # Something is horribly wrong
|
|
||||||
# logger.error(f"Two customers found for email {email}")
|
|
||||||
# raise Exception(f"Two customers found for email {email}")
|
|
||||||
|
|
||||||
# elif len(matching_customers) == 1:
|
|
||||||
# # We found a customer. Let's copy the ID
|
|
||||||
# customer = matching_customers["data"][0]
|
|
||||||
# customer_id = customer["id"]
|
|
||||||
# return customer_id
|
|
||||||
|
|
||||||
# else:
|
|
||||||
# # We didn't find anything. Create the customer
|
|
||||||
|
|
||||||
# # Create a name, since we have 2 variables which could be null
|
|
||||||
# name = expand_name(first_name, last_name)
|
|
||||||
# cast = {"email": email}
|
|
||||||
# if name:
|
|
||||||
# cast["name"] = name
|
|
||||||
# customer = stripe.Customer.create(**cast)
|
|
||||||
# logger.info(f"Created new Stripe customer {customer.id} with email {email}")
|
|
||||||
|
|
||||||
# return customer.id
|
|
||||||
|
|
||||||
|
|
||||||
# def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None):
|
|
||||||
# """
|
|
||||||
# Update the customer fields in Stripe.
|
|
||||||
# """
|
|
||||||
# if email:
|
|
||||||
# stripe.Customer.modify(stripe_id, email=email)
|
|
||||||
# logger.info(f"Modified Stripe customer {stripe_id} to have email {email}")
|
|
||||||
# if first_name or last_name:
|
|
||||||
# name = expand_name(first_name, last_name)
|
|
||||||
# stripe.Customer.modify(stripe_id, name=name)
|
|
||||||
# logger.info(f"Modified Stripe customer {stripe_id} to have email {name}")
|
|
|
@ -1,9 +1,10 @@
|
||||||
# Generated by Django 4.1.7 on 2023-02-24 13:18
|
# Generated by Django 4.1.7 on 2023-02-24 13:18
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
import django.db.models.deletion
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
class Migration(migrations.Migration):
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Generated by Django 4.1.7 on 2023-02-24 13:21
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('core', '0002_account_activemanagementpolicy_assetgroup_hook_and_more'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='user',
|
||||||
|
name='customer_id',
|
||||||
|
field=models.UUIDField(blank=True, default=uuid.uuid4, null=True),
|
||||||
|
),
|
||||||
|
]
|
|
@ -0,0 +1,18 @@
|
||||||
|
# Generated by Django 4.1.7 on 2023-02-24 16:09
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('core', '0003_user_customer_id'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='user',
|
||||||
|
name='stripe_id',
|
||||||
|
field=models.CharField(blank=True, max_length=255, null=True),
|
||||||
|
),
|
||||||
|
]
|
|
@ -1,8 +1,8 @@
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
# import stripe
|
import stripe
|
||||||
# from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth.models import AbstractUser
|
from django.contrib.auth.models import AbstractUser
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ from core.exchanges.fake import FakeExchange
|
||||||
from core.exchanges.oanda import OANDAExchange
|
from core.exchanges.oanda import OANDAExchange
|
||||||
|
|
||||||
# from core.lib.customers import get_or_create, update_customer_fields
|
# from core.lib.customers import get_or_create, update_customer_fields
|
||||||
|
from core.lib import billing
|
||||||
from core.util import logs
|
from core.util import logs
|
||||||
|
|
||||||
log = logs.get_logger(__name__)
|
log = logs.get_logger(__name__)
|
||||||
|
@ -93,21 +94,36 @@ ADJUST_CLOSE_NOTIFY_CHOICES = (
|
||||||
|
|
||||||
class User(AbstractUser):
|
class User(AbstractUser):
|
||||||
# Stripe customer ID
|
# Stripe customer ID
|
||||||
# unique_id = models.UUIDField(
|
stripe_id = models.CharField(max_length=255, null=True, blank=True)
|
||||||
# default=uuid.uuid4,
|
customer_id = models.UUIDField(default=uuid.uuid4, null=True, blank=True)
|
||||||
# )
|
|
||||||
# stripe_id = models.CharField(max_length=255, null=True, blank=True)
|
|
||||||
billing_provider_id = models.CharField(max_length=255, null=True, blank=True)
|
billing_provider_id = models.CharField(max_length=255, null=True, blank=True)
|
||||||
# last_payment = models.DateTimeField(null=True, blank=True)
|
# last_payment = models.DateTimeField(null=True, blank=True)
|
||||||
# plans = models.ManyToManyField(Plan, blank=True)
|
# plans = models.ManyToManyField(Plan, blank=True)
|
||||||
email = models.EmailField(unique=True)
|
email = models.EmailField(unique=True)
|
||||||
|
|
||||||
# def __init__(self, *args, **kwargs):
|
def delete(self, *args, **kwargs):
|
||||||
# super().__init__(*args, **kwargs)
|
if settings.BILLING_ENABLED:
|
||||||
# self._original = self
|
if self.stripe_id:
|
||||||
|
stripe.Customer.delete(self.stripe_id)
|
||||||
|
log.info(f"Deleted Stripe customer {self.stripe_id}")
|
||||||
|
super().delete(*args, **kwargs)
|
||||||
|
|
||||||
# Override save to update attributes in Lago
|
# Override save to update attributes in Lago
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
|
if self.customer_id is None:
|
||||||
|
self.customer_id = uuid.uuid4()
|
||||||
|
|
||||||
|
if settings.BILLING_ENABLED:
|
||||||
|
if not self.stripe_id: # stripe ID not stored
|
||||||
|
self.stripe_id = billing.get_or_create(
|
||||||
|
self.email, self.first_name, self.last_name
|
||||||
|
)
|
||||||
|
|
||||||
|
billing_id = billing.create_or_update_customer(self)
|
||||||
|
self.billing_provider_id = billing_id
|
||||||
|
|
||||||
|
billing.update_customer_fields(self)
|
||||||
|
|
||||||
super().save(*args, **kwargs)
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
def get_notification_settings(self):
|
def get_notification_settings(self):
|
||||||
|
|
|
@ -288,7 +288,7 @@
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% if settings.BILLING_ENABLED %}
|
{% if settings.BILLING_ENABLED %}
|
||||||
{% if user.is_authenticated %}
|
{% if user.is_authenticated %}
|
||||||
<a class="navbar-item" href="{# url 'billing' #}">
|
<a class="navbar-item" href="{% url 'billing' %}">
|
||||||
Billing
|
Billing
|
||||||
</a>
|
</a>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
|
@ -10,20 +10,6 @@
|
||||||
</span>
|
</span>
|
||||||
<span class="tag">{{ user.first_name }} {{ user.last_name }}</span>
|
<span class="tag">{{ user.first_name }} {{ user.last_name }}</span>
|
||||||
</a>
|
</a>
|
||||||
<a class="panel-block">
|
|
||||||
<span class="panel-icon">
|
|
||||||
<i class="fas fa-binary" aria-hidden="true"></i>
|
|
||||||
</span>
|
|
||||||
{% for plan in user.plans.all %}
|
|
||||||
<span class="tag">{{ plan.name }}</span>
|
|
||||||
{% endfor %}
|
|
||||||
</a>
|
|
||||||
<a class="panel-block">
|
|
||||||
<span class="panel-icon">
|
|
||||||
<i class="fas fa-credit-card" aria-hidden="true"></i>
|
|
||||||
</span>
|
|
||||||
<span class="tag">{{ user.last_payment }}</span>
|
|
||||||
</a>
|
|
||||||
<a class="panel-block" href="{% url 'portal' %}">
|
<a class="panel-block" href="{% url 'portal' %}">
|
||||||
<span class="panel-icon">
|
<span class="panel-icon">
|
||||||
<i class="fa-brands fa-stripe-s" aria-hidden="true"></i>
|
<i class="fa-brands fa-stripe-s" aria-hidden="true"></i>
|
||||||
|
|
|
@ -413,8 +413,6 @@ class ActiveManagementLiveTestCase(ElasticMock, StrategyMixin, LiveBase, TestCas
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
print("ACTIONS", self.ams.actions)
|
|
||||||
print("EXP", expected)
|
|
||||||
|
|
||||||
self.assertEqual(self.ams.actions, expected)
|
self.assertEqual(self.ams.actions, expected)
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import stripe
|
import stripe
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||||
from django.http import JsonResponse
|
|
||||||
from django.shortcuts import redirect, render
|
from django.shortcuts import redirect, render
|
||||||
from django.urls import reverse, reverse_lazy
|
from django.urls import reverse, reverse_lazy
|
||||||
from django.views import View
|
from django.views import View
|
||||||
|
@ -14,7 +12,7 @@ from core.forms import NewUserForm
|
||||||
|
|
||||||
# from core.lib.products import assemble_plan_map
|
# from core.lib.products import assemble_plan_map
|
||||||
# from core.models import Plan, Session
|
# from core.models import Plan, Session
|
||||||
from core.lib import billing
|
# from core.lib import billing
|
||||||
from core.lib.notify import raw_sendmsg
|
from core.lib.notify import raw_sendmsg
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -29,16 +27,13 @@ class Home(LoginRequiredMixin, View):
|
||||||
return render(request, self.template_name)
|
return render(request, self.template_name)
|
||||||
|
|
||||||
|
|
||||||
# class Billing(LoginRequiredMixin, View):
|
class Billing(LoginRequiredMixin, View):
|
||||||
# template_name = "billing.html"
|
template_name = "billing.html"
|
||||||
|
|
||||||
# async def get(self, request):
|
def get(self, request):
|
||||||
# if not settings.STRIPE_ENABLED:
|
if not settings.BILLING_ENABLED:
|
||||||
# return redirect(reverse("home"))
|
return redirect(reverse("home"))
|
||||||
# plans = await sync_to_async(list)(Plan.objects.all())
|
return render(request, self.template_name)
|
||||||
# user_plans = await sync_to_async(list)(request.user.plans.all())
|
|
||||||
# context = {"plans": plans, "user_plans": user_plans}
|
|
||||||
# return render(request, self.template_name, context)
|
|
||||||
|
|
||||||
|
|
||||||
# class Order(LoginRequiredMixin, View):
|
# class Order(LoginRequiredMixin, View):
|
||||||
|
@ -110,12 +105,13 @@ class Signup(CreateView):
|
||||||
return super().get(request, *args, **kwargs)
|
return super().get(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# class Portal(LoginRequiredMixin, View):
|
class Portal(LoginRequiredMixin, View):
|
||||||
# async def get(self, request):
|
def get(self, request):
|
||||||
# if not settings.STRIPE_ENABLED:
|
if not settings.BILLING_ENABLED:
|
||||||
# return redirect(reverse("home"))
|
return redirect(reverse("home"))
|
||||||
# session = stripe.billing_portal.Session.create(
|
|
||||||
# customer=request.user.stripe_id,
|
session = stripe.billing_portal.Session.create(
|
||||||
# return_url=request.build_absolute_uri(reverse("billing")),
|
customer=request.user.stripe_id,
|
||||||
# )
|
return_url=request.build_absolute_uri(reverse("billing")),
|
||||||
# return redirect(session.url)
|
)
|
||||||
|
return redirect(session.url)
|
||||||
|
|
Loading…
Reference in New Issue