Begin implementing billing

This commit is contained in:
Mark Veidemanis 2023-02-24 07:20:31 +00:00
parent 0937f7299a
commit ac4c248175
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
10 changed files with 268 additions and 261 deletions

View File

@ -13,7 +13,7 @@ ALLOWED_HOSTS = getenv("ALLOWED_HOSTS", f"127.0.0.1,{DOMAIN}").split(",")
CSRF_TRUSTED_ORIGINS = getenv("CSRF_TRUSTED_ORIGINS", URL).split(",") CSRF_TRUSTED_ORIGINS = getenv("CSRF_TRUSTED_ORIGINS", URL).split(",")
# Stripe # Stripe
STRIPE_ENABLED = getenv("STRIPE_ENABLED", "false").lower() in trues BILLING_ENABLED = getenv("BILLING_ENABLED", "false").lower() in trues
STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues
STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "") STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "")
STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "") STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "")
@ -56,4 +56,4 @@ if DEBUG:
"10.0.2.2", "10.0.2.2",
] ]
SETTINGS_EXPORT = ["STRIPE_ENABLED", "URL", "HOOK_PATH", "ASSET_PATH"] SETTINGS_EXPORT = ["BILLING_ENABLED", "URL", "HOOK_PATH", "ASSET_PATH"]

View File

@ -38,24 +38,25 @@ from core.views import (
strategies, strategies,
trades, trades,
) )
from core.views.stripe_callbacks import Callback
# from core.views.stripe_callbacks import Callback
urlpatterns = [ urlpatterns = [
path("__debug__/", include("debug_toolbar.urls")), path("__debug__/", include("debug_toolbar.urls")),
path("", base.Home.as_view(), name="home"), path("", base.Home.as_view(), name="home"),
path("callback", Callback.as_view(), name="callback"), # path("callback", Callback.as_view(), name="callback"),
path("billing/", base.Billing.as_view(), name="billing"), # path("billing/", base.Billing.as_view(), name="billing"),
path("order/<str:plan_name>/", base.Order.as_view(), name="order"), # path("order/<str:plan_name>/", base.Order.as_view(), name="order"),
path( # path(
"cancel_subscription/<str:plan_name>/", # "cancel_subscription/<str:plan_name>/",
base.Cancel.as_view(), # base.Cancel.as_view(),
name="cancel_subscription", # name="cancel_subscription",
), # ),
path( # path(
"success/", TemplateView.as_view(template_name="success.html"), name="success" # "success/", TemplateView.as_view(template_name="success.html"), name="success"
), # ),
path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"), # path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"),
path("portal", base.Portal.as_view(), name="portal"), # path("portal", base.Portal.as_view(), name="portal"),
path("sapp/", admin.site.urls), path("sapp/", admin.site.urls),
# 2FA login urls # 2FA login urls
path("", include(tf_urls)), path("", include(tf_urls)),

View File

@ -1,4 +1,11 @@
# from lago_python_client import Client from django.conf import settings
# from django.conf import settings from lago_python_client import Client
from lago_python_client.models import (
Charge,
Charges,
Customer,
CustomerBillingConfiguration,
Plan,
)
# client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL) client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL)

View File

@ -1,65 +1,65 @@
import logging # import logging
import stripe # import stripe
logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
def expand_name(first_name, last_name): # def expand_name(first_name, last_name):
""" # """
Convert two name variables into one. # Convert two name variables into one.
Last name without a first name is ignored. # Last name without a first name is ignored.
""" # """
name = None # name = None
if first_name: # if first_name:
name = first_name # name = first_name
# We only want to put the last name if we have a first name # # We only want to put the last name if we have a first name
if last_name: # if last_name:
name += f" {last_name}" # name += f" {last_name}"
return name # return name
def get_or_create(email, first_name, last_name): # def get_or_create(email, first_name, last_name):
""" # """
Get a customer ID from Stripe if one with the given email exists. # Get a customer ID from Stripe if one with the given email exists.
Create a customer if one does not. # Create a customer if one does not.
Raise an exception if two or more customers matching the given email exist. # Raise an exception if two or more customers matching the given email exist.
""" # """
# Let's see if we're just missing the ID # # Let's see if we're just missing the ID
matching_customers = stripe.Customer.list(email=email, limit=2) # matching_customers = stripe.Customer.list(email=email, limit=2)
if len(matching_customers) == 2: # if len(matching_customers) == 2:
# Something is horribly wrong # # Something is horribly wrong
logger.error(f"Two customers found for email {email}") # logger.error(f"Two customers found for email {email}")
raise Exception(f"Two customers found for email {email}") # raise Exception(f"Two customers found for email {email}")
elif len(matching_customers) == 1: # elif len(matching_customers) == 1:
# We found a customer. Let's copy the ID # # We found a customer. Let's copy the ID
customer = matching_customers["data"][0] # customer = matching_customers["data"][0]
customer_id = customer["id"] # customer_id = customer["id"]
return customer_id # return customer_id
else: # else:
# We didn't find anything. Create the customer # # We didn't find anything. Create the customer
# Create a name, since we have 2 variables which could be null # # Create a name, since we have 2 variables which could be null
name = expand_name(first_name, last_name) # name = expand_name(first_name, last_name)
cast = {"email": email} # cast = {"email": email}
if name: # if name:
cast["name"] = name # cast["name"] = name
customer = stripe.Customer.create(**cast) # customer = stripe.Customer.create(**cast)
logger.info(f"Created new Stripe customer {customer.id} with email {email}") # logger.info(f"Created new Stripe customer {customer.id} with email {email}")
return customer.id # return customer.id
def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None): # def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None):
""" # """
Update the customer fields in Stripe. # Update the customer fields in Stripe.
""" # """
if email: # if email:
stripe.Customer.modify(stripe_id, email=email) # stripe.Customer.modify(stripe_id, email=email)
logger.info(f"Modified Stripe customer {stripe_id} to have email {email}") # logger.info(f"Modified Stripe customer {stripe_id} to have email {email}")
if first_name or last_name: # if first_name or last_name:
name = expand_name(first_name, last_name) # name = expand_name(first_name, last_name)
stripe.Customer.modify(stripe_id, name=name) # stripe.Customer.modify(stripe_id, name=name)
logger.info(f"Modified Stripe customer {stripe_id} to have email {name}") # logger.info(f"Modified Stripe customer {stripe_id} to have email {name}")

View File

@ -1,21 +1,21 @@
from asgiref.sync import sync_to_async # from asgiref.sync import sync_to_async
from core.models import Plan # from core.models import Plan
async def assemble_plan_map(product_id_filter=None): # async def assemble_plan_map(product_id_filter=None):
""" # """
Get all the plans from the database and create an object Stripe wants. # Get all the plans from the database and create an object Stripe wants.
""" # """
line_items = [] # line_items = []
for plan in await sync_to_async(list)(Plan.objects.all()): # for plan in await sync_to_async(list)(Plan.objects.all()):
if product_id_filter: # if product_id_filter:
if plan.product_id != product_id_filter: # if plan.product_id != product_id_filter:
continue # continue
line_items.append( # line_items.append(
{ # {
"price": plan.product_id, # "price": plan.product_id,
"quantity": 1, # "quantity": 1,
} # }
) # )
return line_items # return line_items

View File

@ -1,5 +1,4 @@
{% load static %} {% load static %}
{% load has_plan %}
{% load cache %} {% load cache %}
<!DOCTYPE html> <!DOCTYPE html>
@ -287,9 +286,9 @@
</div> </div>
</div> </div>
{% endif %} {% endif %}
{% if settings.STRIPE_ENABLED %} {% if settings.BILLING_ENABLED %}
{% if user.is_authenticated %} {% if user.is_authenticated %}
<a class="navbar-item" href="{% url 'billing' %}"> <a class="navbar-item" href="{# url 'billing' #}">
Billing Billing
</a> </a>
{% endif %} {% endif %}

View File

@ -1,11 +0,0 @@
from django import template
register = template.Library()
@register.filter
def has_plan(user, plan_name):
if not hasattr(user, "plans"):
return False
plan_list = [plan.name for plan in user.plans.all()]
return plan_name in plan_list

View File

@ -5,6 +5,15 @@ from django.test import TestCase
from core.models import TradingTime, User from core.models import TradingTime, User
class ModelTestCase(TestCase):
def setUp(self):
# Create a test user
self.user = User.objects.create_user(
username="testuser",
email="testuser@example.com",
)
class MarketTestCase(TestCase): class MarketTestCase(TestCase):
def setUp(self): def setUp(self):
# Create a test user # Create a test user

View File

@ -11,9 +11,11 @@ from django.views import View
from django.views.generic.edit import CreateView from django.views.generic.edit import CreateView
from core.forms import NewUserForm from core.forms import NewUserForm
# from core.lib.products import assemble_plan_map
# from core.models import Plan, Session
from core.lib import billing
from core.lib.notify import raw_sendmsg from core.lib.notify import raw_sendmsg
from core.lib.products import assemble_plan_map
from core.models import Plan, Session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,64 +29,64 @@ class Home(LoginRequiredMixin, View):
return render(request, self.template_name) return render(request, self.template_name)
class Billing(LoginRequiredMixin, View): # class Billing(LoginRequiredMixin, View):
template_name = "billing.html" # template_name = "billing.html"
async def get(self, request): # async def get(self, request):
if not settings.STRIPE_ENABLED: # if not settings.STRIPE_ENABLED:
return redirect(reverse("home")) # return redirect(reverse("home"))
plans = await sync_to_async(list)(Plan.objects.all()) # plans = await sync_to_async(list)(Plan.objects.all())
user_plans = await sync_to_async(list)(request.user.plans.all()) # user_plans = await sync_to_async(list)(request.user.plans.all())
context = {"plans": plans, "user_plans": user_plans} # context = {"plans": plans, "user_plans": user_plans}
return render(request, self.template_name, context) # return render(request, self.template_name, context)
class Order(LoginRequiredMixin, View): # class Order(LoginRequiredMixin, View):
async def get(self, request, plan_name): # async def get(self, request, plan_name):
if not settings.STRIPE_ENABLED: # if not settings.STRIPE_ENABLED:
return redirect(reverse("home")) # return redirect(reverse("home"))
plan = Plan.objects.get(name=plan_name) # plan = Plan.objects.get(name=plan_name)
try: # try:
cast = { # cast = {
"payment_method_types": settings.ALLOWED_PAYMENT_METHODS, # "payment_method_types": settings.ALLOWED_PAYMENT_METHODS,
"mode": "subscription", # "mode": "subscription",
"customer": request.user.stripe_id, # "customer": request.user.stripe_id,
"line_items": await assemble_plan_map( # "line_items": await assemble_plan_map(
product_id_filter=plan.product_id # product_id_filter=plan.product_id
), # ),
"success_url": request.build_absolute_uri(reverse("success")), # "success_url": request.build_absolute_uri(reverse("success")),
"cancel_url": request.build_absolute_uri(reverse("cancel")), # "cancel_url": request.build_absolute_uri(reverse("cancel")),
} # }
if request.user.is_superuser: # if request.user.is_superuser:
cast["discounts"] = [{"coupon": settings.STRIPE_ADMIN_COUPON}] # cast["discounts"] = [{"coupon": settings.STRIPE_ADMIN_COUPON}]
session = stripe.checkout.Session.create(**cast) # session = stripe.checkout.Session.create(**cast)
await Session.objects.acreate(user=request.user, session=session.id) # await Session.objects.acreate(user=request.user, session=session.id)
return redirect(session.url) # return redirect(session.url)
# return JsonResponse({'id': session.id}) # # return JsonResponse({'id': session.id})
except Exception as e: # except Exception as e:
# Raise a server error # # Raise a server error
return JsonResponse({"error": str(e)}, status=500) # return JsonResponse({"error": str(e)}, status=500)
class Cancel(LoginRequiredMixin, View): # class Cancel(LoginRequiredMixin, View):
async def get(self, request, plan_name): # async def get(self, request, plan_name):
if not settings.STRIPE_ENABLED: # if not settings.STRIPE_ENABLED:
return redirect(reverse("home")) # return redirect(reverse("home"))
plan = Plan.objects.get(name=plan_name) # plan = Plan.objects.get(name=plan_name)
try: # try:
subscriptions = stripe.Subscription.list( # subscriptions = stripe.Subscription.list(
customer=request.user.stripe_id, price=plan.product_id # customer=request.user.stripe_id, price=plan.product_id
) # )
for subscription in subscriptions["data"]: # for subscription in subscriptions["data"]:
items = subscription["items"]["data"] # items = subscription["items"]["data"]
for item in items: # for item in items:
stripe.Subscription.delete(item["subscription"]) # stripe.Subscription.delete(item["subscription"])
return render(request, "subscriptioncancel.html", {"plan": plan}) # return render(request, "subscriptioncancel.html", {"plan": plan})
# return JsonResponse({'id': session.id}) # # return JsonResponse({'id': session.id})
except Exception as e: # except Exception as e:
# Raise a server error # # Raise a server error
logging.error(f"Error cancelling subscription for user: {e}") # logging.error(f"Error cancelling subscription for user: {e}")
return JsonResponse({"error": "True"}, status=500) # return JsonResponse({"error": "True"}, status=500)
class Signup(CreateView): class Signup(CreateView):
@ -108,12 +110,12 @@ class Signup(CreateView):
return super().get(request, *args, **kwargs) return super().get(request, *args, **kwargs)
class Portal(LoginRequiredMixin, View): # class Portal(LoginRequiredMixin, View):
async def get(self, request): # async def get(self, request):
if not settings.STRIPE_ENABLED: # if not settings.STRIPE_ENABLED:
return redirect(reverse("home")) # return redirect(reverse("home"))
session = stripe.billing_portal.Session.create( # session = stripe.billing_portal.Session.create(
customer=request.user.stripe_id, # customer=request.user.stripe_id,
return_url=request.build_absolute_uri(reverse("billing")), # return_url=request.build_absolute_uri(reverse("billing")),
) # )
return redirect(session.url) # return redirect(session.url)

View File

@ -1,104 +1,104 @@
import logging # import logging
from datetime import datetime # from datetime import datetime
import stripe # import stripe
from django.conf import settings # from django.conf import settings
from django.http import HttpResponse, JsonResponse # from django.http import HttpResponse, JsonResponse
from django.views.decorators.csrf import csrf_exempt # from django.views.decorators.csrf import csrf_exempt
from rest_framework.parsers import JSONParser # from rest_framework.parsers import JSONParser
from rest_framework.views import APIView # from rest_framework.views import APIView
from core.models import Plan, Session, User # from core.models import Plan, Session, User
logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
class Callback(APIView): # class Callback(APIView):
parser_classes = [JSONParser] # parser_classes = [JSONParser]
# TODO: make async # # TODO: make async
@csrf_exempt # @csrf_exempt
def post(self, request): # def post(self, request):
payload = request.body # payload = request.body
sig_header = request.META["HTTP_STRIPE_SIGNATURE"] # sig_header = request.META["HTTP_STRIPE_SIGNATURE"]
try: # try:
stripe.Webhook.construct_event( # stripe.Webhook.construct_event(
payload, sig_header, settings.STRIPE_ENDPOINT_SECRET # payload, sig_header, settings.STRIPE_ENDPOINT_SECRET
) # )
except ValueError: # except ValueError:
# Invalid payload # # Invalid payload
logger.error("Invalid payload") # logger.error("Invalid payload")
return HttpResponse(status=400) # return HttpResponse(status=400)
except stripe.error.SignatureVerificationError: # except stripe.error.SignatureVerificationError:
# Invalid signature # # Invalid signature
logger.error("Invalid signature") # logger.error("Invalid signature")
return HttpResponse(status=400) # return HttpResponse(status=400)
if request.data is None: # if request.data is None:
return JsonResponse({"success": False}, status=500) # return JsonResponse({"success": False}, status=500)
if "type" in request.data.keys(): # if "type" in request.data.keys():
rtype = request.data["type"] # rtype = request.data["type"]
if rtype == "checkout.session.completed": # if rtype == "checkout.session.completed":
session = request.data["data"]["object"]["id"] # session = request.data["data"]["object"]["id"]
subscription_id = request.data["data"]["object"]["subscription"] # subscription_id = request.data["data"]["object"]["subscription"]
session_map = Session.objects.get(session=session) # session_map = Session.objects.get(session=session)
if not session_map: # if not session_map:
return JsonResponse({"success": False}, status=500) # return JsonResponse({"success": False}, status=500)
user = session_map.user # user = session_map.user
session_map.subscription_id = subscription_id # session_map.subscription_id = subscription_id
session_map.save() # session_map.save()
if rtype == "customer.subscription.updated": # if rtype == "customer.subscription.updated":
stripe_id = request.data["data"]["object"]["customer"] # stripe_id = request.data["data"]["object"]["customer"]
if not stripe_id: # if not stripe_id:
logging.error("No stripe id") # logging.error("No stripe id")
return JsonResponse({"success": False}, status=500) # return JsonResponse({"success": False}, status=500)
user = User.objects.get(stripe_id=stripe_id) # user = User.objects.get(stripe_id=stripe_id)
# ssubscription_active # # ssubscription_active
subscription_id = request.data["data"]["object"]["id"] # subscription_id = request.data["data"]["object"]["id"]
sessions = Session.objects.filter(user=user) # sessions = Session.objects.filter(user=user)
session = None # session = None
for session_iter in sessions: # for session_iter in sessions:
if session_iter.subscription_id == subscription_id: # if session_iter.subscription_id == subscription_id:
session = session_iter # session = session_iter
if not session: # if not session:
logging.error( # logging.error(
f"No session found for subscription id {subscription_id}" # f"No session found for subscription id {subscription_id}"
) # )
return JsonResponse({"success": False}, status=500) # return JsonResponse({"success": False}, status=500)
# query Session objects # # query Session objects
# iterate and check against product_id # # iterate and check against product_id
session.request = request.data["request"]["id"] # session.request = request.data["request"]["id"]
product_id = request.data["data"]["object"]["plan"]["id"] # product_id = request.data["data"]["object"]["plan"]["id"]
plan = Plan.objects.get(product_id=product_id) # plan = Plan.objects.get(product_id=product_id)
if not plan: # if not plan:
logging.error(f"Plan not found: {product_id}") # logging.error(f"Plan not found: {product_id}")
return JsonResponse({"success": False}, status=500) # return JsonResponse({"success": False}, status=500)
session.plan = plan # session.plan = plan
session.save() # session.save()
elif rtype == "payment_intent.succeeded": # elif rtype == "payment_intent.succeeded":
customer = request.data["data"]["object"]["customer"] # customer = request.data["data"]["object"]["customer"]
user = User.objects.get(stripe_id=customer) # user = User.objects.get(stripe_id=customer)
if not user: # if not user:
logging.error(f"No user found for customer: {customer}") # logging.error(f"No user found for customer: {customer}")
return JsonResponse({"success": False}, status=500) # return JsonResponse({"success": False}, status=500)
session = Session.objects.get(request=request.data["request"]["id"]) # session = Session.objects.get(request=request.data["request"]["id"])
user.plans.add(session.plan) # user.plans.add(session.plan)
user.last_payment = datetime.utcnow() # user.last_payment = datetime.utcnow()
user.save() # user.save()
elif rtype == "customer.subscription.deleted": # elif rtype == "customer.subscription.deleted":
customer = request.data["data"]["object"]["customer"] # customer = request.data["data"]["object"]["customer"]
user = User.objects.get(stripe_id=customer) # user = User.objects.get(stripe_id=customer)
if not user: # if not user:
logging.error(f"No user found for customer {customer}") # logging.error(f"No user found for customer {customer}")
return JsonResponse({"success": False}, status=500) # return JsonResponse({"success": False}, status=500)
product_id = request.data["data"]["object"]["plan"]["id"] # product_id = request.data["data"]["object"]["plan"]["id"]
plan = Plan.objects.get(product_id=product_id) # plan = Plan.objects.get(product_id=product_id)
user.plans.remove(plan) # user.plans.remove(plan)
user.save() # user.save()
else: # else:
return JsonResponse({"success": False}, status=500) # return JsonResponse({"success": False}, status=500)
return JsonResponse({"success": True}) # return JsonResponse({"success": True})