Implement subscription management and ordering

This commit is contained in:
2022-07-21 13:48:18 +01:00
parent e0390f383c
commit a30e2afdd1
12 changed files with 170 additions and 22 deletions

View File

@@ -1,5 +1,10 @@
import logging
import stripe
logger = logging.getLogger(__name__)
def expand_name(first_name, last_name):
"""
Convert two name variables into one.
@@ -13,6 +18,7 @@ def expand_name(first_name, last_name):
name += f" {last_name}"
return name
def get_or_create(email, first_name, last_name):
"""
Get a customer ID from Stripe if one with the given email exists.
@@ -23,7 +29,7 @@ def get_or_create(email, first_name, last_name):
matching_customers = stripe.Customer.list(email=email, limit=2)
if len(matching_customers) == 2:
# Something is horribly wrong
print("Two customers match!")
logger.error(f"Two customers found for email {email}")
raise Exception(f"Two customers found for email {email}")
elif len(matching_customers) == 1:
@@ -41,9 +47,11 @@ def get_or_create(email, first_name, last_name):
if name:
cast["name"] = name
customer = stripe.Customer.create(**cast)
logger.info(f"Created new Stripe customer {customer.id} with email {email}")
return customer.id
def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None):
"""
Update the customer fields in Stripe.
@@ -52,7 +60,9 @@ def update_customer_fields(stripe_id, email=None, first_name=None, last_name=Non
if email:
print("Email modified")
stripe.Customer.modify(stripe_id, email=email)
logger.info(f"Modified Stripe customer {stripe_id} to have email {email}")
if first_name or last_name:
print("Name modified")
name = expand_name(first_name, last_name)
stripe.Customer.modify(stripe_id, name=name)
stripe.Customer.modify(stripe_id, name=name)
logger.info(f"Modified Stripe customer {stripe_id} to have email {name}")

16
core/lib/products.py Normal file
View File

@@ -0,0 +1,16 @@
from core.models import Plan
def assemble_plan_map(product_id_filter=None):
"""
Get all the plans from the database and create an object Stripe wants.
"""
line_items = []
for plan in Plan.objects.all():
if product_id_filter:
if plan.product_id != product_id_filter:
continue
line_items.append({
"price": plan.product_id,
"quantity": 1,
})
return line_items

View File

@@ -0,0 +1,23 @@
# Generated by Django 4.0.6 on 2022-07-09 09:41
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='plan',
name='product_id',
field=models.CharField(blank=True, max_length=255, null=True, unique=True),
),
migrations.AlterField(
model_name='user',
name='email',
field=models.EmailField(max_length=254, unique=True),
),
]

View File

@@ -0,0 +1,21 @@
# Generated by Django 4.0.6 on 2022-07-09 15:20
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0002_alter_plan_product_id_alter_user_email'),
]
operations = [
migrations.CreateModel(
name='Session',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('email', models.EmailField(max_length=254, unique=True)),
('session', models.CharField(max_length=255)),
],
),
]

View File

@@ -0,0 +1,18 @@
# Generated by Django 4.0.6 on 2022-07-09 15:22
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0003_session'),
]
operations = [
migrations.AlterField(
model_name='session',
name='email',
field=models.EmailField(max_length=254),
),
]

View File

@@ -1,14 +1,19 @@
import logging
import stripe
from django.contrib.auth.models import AbstractUser
from django.db import models
from core.lib.customers import get_or_create, update_customer_fields
logger = logging.getLogger(__name__)
class Plan(models.Model):
name = models.CharField(max_length=255, unique=True)
description = models.CharField(max_length=1024, null=True, blank=True)
cost = models.IntegerField()
product_id = models.UUIDField(null=True, blank=True)
product_id = models.CharField(max_length=255, unique=True, null=True, blank=True)
image = models.CharField(max_length=1024, null=True, blank=True)
def __str__(self):
@@ -23,11 +28,11 @@ class User(AbstractUser):
last_payment = models.DateTimeField(null=True, blank=True)
paid = models.BooleanField(null=True, blank=True)
plans = models.ManyToManyField(Plan, blank=True)
email = models.EmailField(unique=True)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._original = self
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._original = self
def save(self, *args, **kwargs):
"""
@@ -48,8 +53,19 @@ def __init__(self, *args, **kwargs):
super().save(*args, **kwargs)
def delete(self, *args, **kwargs):
if self.stripe_id:
stripe.Customer.delete(self.stripe_id)
logger.info(f"Deleted Stripe customer {self.stripe_id}")
super().delete(*args, **kwargs)
def has_plan(self, plan):
if not self.paid: # We can't have any plans if we haven't paid
return False
plan_list = [plan.name for plan in self.plans.all()]
return plan in plan_list
class Session(models.Model):
email = models.EmailField()
session = models.CharField(max_length=255)

View File

@@ -34,6 +34,9 @@
Last payment
</li>
</ul>
<form action="{% url 'portal' %}">
<input class="btn btn-lg btn-dark btn-block" type="submit" value="Subscription management">
</form>
</div>
</div>
@@ -46,6 +49,7 @@
</div>
</div>
</div>
</div>
</div>

View File

@@ -1,10 +1,10 @@
{% load static %}
{% for plan in plans %}
{% if plan not in user_plans %}
<a href="order?product={{ plan.name }}">
<a href="/order/{{ plan.product_id}}">
{% endif %}
<div class="product">
<img src="{% static plan.image %}" alt="Data image"/>
<div class="description">
@@ -15,10 +15,10 @@
{% endif %}
</div>
</div>
{% if plan not in user_plans %}
</a>
{% endif %}
{% endfor %}
<script type="text/javascript">

View File

@@ -1,12 +1,11 @@
{% extends "base.html" %}
{% block content %}
<head>
<title>Thanks for your order!</title>
<link rel="stylesheet" href="style.css">
</head>
<section>
<p>We appreciate your business!</p>
<p>Download details will be available <a href="{{ url_for('main.profile') }} ">in your profile</a> shortly.</p>
</section>
<h1 class="title">Pathogen Data Insights</h1>
<div class="container">
<h2 class="subtitle">Thank you for your order!</h2>
<div class="col">
<h2 class="subtitle">The customer portal will be available <a href="{% url 'billing' %} ">in your profile</a> shortly.</h2>
</div>
</div>
{% endblock %}

View File

@@ -1,11 +1,15 @@
import stripe
from django.conf import settings
from django.contrib.auth.mixins import LoginRequiredMixin
from django.shortcuts import render
from django.urls import reverse_lazy
from django.http import JsonResponse
from django.shortcuts import redirect, render
from django.urls import reverse, reverse_lazy
from django.views import View
from django.views.generic.edit import CreateView
from core.forms import NewUserForm
from core.models import Plan
from core.lib.products import assemble_plan_map
from core.models import Plan, Session
# Create your views here
# fmt: off
@@ -28,7 +32,35 @@ class Billing(LoginRequiredMixin, View):
return render(request, self.template_name, context)
class Order(LoginRequiredMixin, View):
def get(self, request, product_id):
try:
session = stripe.checkout.Session.create(
payment_method_types=settings.ALLOWED_PAYMENT_METHODS,
mode='subscription',
customer=request.user.stripe_id,
line_items=assemble_plan_map(product_id_filter=product_id),
success_url=request.build_absolute_uri(reverse("success")),
cancel_url=request.build_absolute_uri(reverse("cancel")),
)
Session.objects.create(email=request.user.email, session=session.id)
return redirect(session.url)
# return JsonResponse({'id': session.id})
except Exception as e:
# Raise a server error
return JsonResponse({"error": str(e)}, status=500)
class Signup(CreateView):
form_class = NewUserForm
success_url = reverse_lazy("login")
template_name = "registration/signup.html"
class Portal(LoginRequiredMixin, View):
def get(self, request):
session = stripe.billing_portal.Session.create(
customer=request.user.stripe_id,
return_url=request.build_absolute_uri(),
)
return redirect(session.url)