From ac4c248175cfbb3f91cb4a075d46e6f1749129a5 Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Fri, 24 Feb 2023 07:20:31 +0000 Subject: [PATCH] Begin implementing billing --- app/local_settings.py | 4 +- app/urls.py | 29 ++--- core/lib/billing.py | 13 ++- core/lib/customers.py | 106 +++++++++---------- core/lib/products.py | 36 +++---- core/templates/base.html | 5 +- core/templatetags/has_plan.py | 11 -- core/tests/test_models.py | 9 ++ core/views/base.py | 140 +++++++++++++------------ core/views/stripe_callbacks.py | 186 ++++++++++++++++----------------- 10 files changed, 273 insertions(+), 266 deletions(-) delete mode 100644 core/templatetags/has_plan.py diff --git a/app/local_settings.py b/app/local_settings.py index ede7b22..b59ab70 100644 --- a/app/local_settings.py +++ b/app/local_settings.py @@ -13,7 +13,7 @@ ALLOWED_HOSTS = getenv("ALLOWED_HOSTS", f"127.0.0.1,{DOMAIN}").split(",") CSRF_TRUSTED_ORIGINS = getenv("CSRF_TRUSTED_ORIGINS", URL).split(",") # Stripe -STRIPE_ENABLED = getenv("STRIPE_ENABLED", "false").lower() in trues +BILLING_ENABLED = getenv("BILLING_ENABLED", "false").lower() in trues STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "") STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "") @@ -56,4 +56,4 @@ if DEBUG: "10.0.2.2", ] -SETTINGS_EXPORT = ["STRIPE_ENABLED", "URL", "HOOK_PATH", "ASSET_PATH"] +SETTINGS_EXPORT = ["BILLING_ENABLED", "URL", "HOOK_PATH", "ASSET_PATH"] diff --git a/app/urls.py b/app/urls.py index 75023cb..0198f96 100644 --- a/app/urls.py +++ b/app/urls.py @@ -38,24 +38,25 @@ from core.views import ( strategies, trades, ) -from core.views.stripe_callbacks import Callback + +# from core.views.stripe_callbacks import Callback urlpatterns = [ path("__debug__/", include("debug_toolbar.urls")), path("", base.Home.as_view(), name="home"), - path("callback", Callback.as_view(), name="callback"), - path("billing/", base.Billing.as_view(), name="billing"), - path("order//", base.Order.as_view(), name="order"), - path( - "cancel_subscription//", - base.Cancel.as_view(), - name="cancel_subscription", - ), - path( - "success/", TemplateView.as_view(template_name="success.html"), name="success" - ), - path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"), - path("portal", base.Portal.as_view(), name="portal"), + # path("callback", Callback.as_view(), name="callback"), + # path("billing/", base.Billing.as_view(), name="billing"), + # path("order//", base.Order.as_view(), name="order"), + # path( + # "cancel_subscription//", + # base.Cancel.as_view(), + # name="cancel_subscription", + # ), + # path( + # "success/", TemplateView.as_view(template_name="success.html"), name="success" + # ), + # path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"), + # path("portal", base.Portal.as_view(), name="portal"), path("sapp/", admin.site.urls), # 2FA login urls path("", include(tf_urls)), diff --git a/core/lib/billing.py b/core/lib/billing.py index c42b973..f4a0dbe 100644 --- a/core/lib/billing.py +++ b/core/lib/billing.py @@ -1,4 +1,11 @@ -# from lago_python_client import Client -# from django.conf import settings +from django.conf import settings +from lago_python_client import Client +from lago_python_client.models import ( + Charge, + Charges, + Customer, + CustomerBillingConfiguration, + Plan, +) -# client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL) +client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL) diff --git a/core/lib/customers.py b/core/lib/customers.py index 35f05a9..9e8dfee 100644 --- a/core/lib/customers.py +++ b/core/lib/customers.py @@ -1,65 +1,65 @@ -import logging +# import logging -import stripe +# import stripe -logger = logging.getLogger(__name__) +# logger = logging.getLogger(__name__) -def expand_name(first_name, last_name): - """ - Convert two name variables into one. - Last name without a first name is ignored. - """ - name = None - if first_name: - name = first_name - # We only want to put the last name if we have a first name - if last_name: - name += f" {last_name}" - return name +# def expand_name(first_name, last_name): +# """ +# Convert two name variables into one. +# Last name without a first name is ignored. +# """ +# name = None +# if first_name: +# name = first_name +# # We only want to put the last name if we have a first name +# if last_name: +# name += f" {last_name}" +# return name -def get_or_create(email, first_name, last_name): - """ - Get a customer ID from Stripe if one with the given email exists. - Create a customer if one does not. - Raise an exception if two or more customers matching the given email exist. - """ - # Let's see if we're just missing the ID - matching_customers = stripe.Customer.list(email=email, limit=2) - if len(matching_customers) == 2: - # Something is horribly wrong - logger.error(f"Two customers found for email {email}") - raise Exception(f"Two customers found for email {email}") +# def get_or_create(email, first_name, last_name): +# """ +# Get a customer ID from Stripe if one with the given email exists. +# Create a customer if one does not. +# Raise an exception if two or more customers matching the given email exist. +# """ +# # Let's see if we're just missing the ID +# matching_customers = stripe.Customer.list(email=email, limit=2) +# if len(matching_customers) == 2: +# # Something is horribly wrong +# logger.error(f"Two customers found for email {email}") +# raise Exception(f"Two customers found for email {email}") - elif len(matching_customers) == 1: - # We found a customer. Let's copy the ID - customer = matching_customers["data"][0] - customer_id = customer["id"] - return customer_id +# elif len(matching_customers) == 1: +# # We found a customer. Let's copy the ID +# customer = matching_customers["data"][0] +# customer_id = customer["id"] +# return customer_id - else: - # We didn't find anything. Create the customer +# else: +# # We didn't find anything. Create the customer - # Create a name, since we have 2 variables which could be null - name = expand_name(first_name, last_name) - cast = {"email": email} - if name: - cast["name"] = name - customer = stripe.Customer.create(**cast) - logger.info(f"Created new Stripe customer {customer.id} with email {email}") +# # Create a name, since we have 2 variables which could be null +# name = expand_name(first_name, last_name) +# cast = {"email": email} +# if name: +# cast["name"] = name +# customer = stripe.Customer.create(**cast) +# logger.info(f"Created new Stripe customer {customer.id} with email {email}") - return customer.id +# return customer.id -def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None): - """ - Update the customer fields in Stripe. - """ - if email: - stripe.Customer.modify(stripe_id, email=email) - logger.info(f"Modified Stripe customer {stripe_id} to have email {email}") - if first_name or last_name: - name = expand_name(first_name, last_name) - stripe.Customer.modify(stripe_id, name=name) - logger.info(f"Modified Stripe customer {stripe_id} to have email {name}") +# def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None): +# """ +# Update the customer fields in Stripe. +# """ +# if email: +# stripe.Customer.modify(stripe_id, email=email) +# logger.info(f"Modified Stripe customer {stripe_id} to have email {email}") +# if first_name or last_name: +# name = expand_name(first_name, last_name) +# stripe.Customer.modify(stripe_id, name=name) +# logger.info(f"Modified Stripe customer {stripe_id} to have email {name}") diff --git a/core/lib/products.py b/core/lib/products.py index 9e9e191..84f97af 100644 --- a/core/lib/products.py +++ b/core/lib/products.py @@ -1,21 +1,21 @@ -from asgiref.sync import sync_to_async +# from asgiref.sync import sync_to_async -from core.models import Plan +# from core.models import Plan -async def assemble_plan_map(product_id_filter=None): - """ - Get all the plans from the database and create an object Stripe wants. - """ - line_items = [] - for plan in await sync_to_async(list)(Plan.objects.all()): - if product_id_filter: - if plan.product_id != product_id_filter: - continue - line_items.append( - { - "price": plan.product_id, - "quantity": 1, - } - ) - return line_items +# async def assemble_plan_map(product_id_filter=None): +# """ +# Get all the plans from the database and create an object Stripe wants. +# """ +# line_items = [] +# for plan in await sync_to_async(list)(Plan.objects.all()): +# if product_id_filter: +# if plan.product_id != product_id_filter: +# continue +# line_items.append( +# { +# "price": plan.product_id, +# "quantity": 1, +# } +# ) +# return line_items diff --git a/core/templates/base.html b/core/templates/base.html index a7b9fd6..3a2bbbf 100644 --- a/core/templates/base.html +++ b/core/templates/base.html @@ -1,5 +1,4 @@ {% load static %} -{% load has_plan %} {% load cache %} @@ -287,9 +286,9 @@ {% endif %} - {% if settings.STRIPE_ENABLED %} + {% if settings.BILLING_ENABLED %} {% if user.is_authenticated %} - + Billing {% endif %} diff --git a/core/templatetags/has_plan.py b/core/templatetags/has_plan.py deleted file mode 100644 index 94a568f..0000000 --- a/core/templatetags/has_plan.py +++ /dev/null @@ -1,11 +0,0 @@ -from django import template - -register = template.Library() - - -@register.filter -def has_plan(user, plan_name): - if not hasattr(user, "plans"): - return False - plan_list = [plan.name for plan in user.plans.all()] - return plan_name in plan_list diff --git a/core/tests/test_models.py b/core/tests/test_models.py index 1a0423a..9122cf6 100644 --- a/core/tests/test_models.py +++ b/core/tests/test_models.py @@ -5,6 +5,15 @@ from django.test import TestCase from core.models import TradingTime, User +class ModelTestCase(TestCase): + def setUp(self): + # Create a test user + self.user = User.objects.create_user( + username="testuser", + email="testuser@example.com", + ) + + class MarketTestCase(TestCase): def setUp(self): # Create a test user diff --git a/core/views/base.py b/core/views/base.py index 87e800f..0405358 100644 --- a/core/views/base.py +++ b/core/views/base.py @@ -11,9 +11,11 @@ from django.views import View from django.views.generic.edit import CreateView from core.forms import NewUserForm + +# from core.lib.products import assemble_plan_map +# from core.models import Plan, Session +from core.lib import billing from core.lib.notify import raw_sendmsg -from core.lib.products import assemble_plan_map -from core.models import Plan, Session logger = logging.getLogger(__name__) @@ -27,64 +29,64 @@ class Home(LoginRequiredMixin, View): return render(request, self.template_name) -class Billing(LoginRequiredMixin, View): - template_name = "billing.html" - - async def get(self, request): - if not settings.STRIPE_ENABLED: - return redirect(reverse("home")) - plans = await sync_to_async(list)(Plan.objects.all()) - user_plans = await sync_to_async(list)(request.user.plans.all()) - context = {"plans": plans, "user_plans": user_plans} - return render(request, self.template_name, context) - - -class Order(LoginRequiredMixin, View): - async def get(self, request, plan_name): - if not settings.STRIPE_ENABLED: - return redirect(reverse("home")) - plan = Plan.objects.get(name=plan_name) - try: - cast = { - "payment_method_types": settings.ALLOWED_PAYMENT_METHODS, - "mode": "subscription", - "customer": request.user.stripe_id, - "line_items": await assemble_plan_map( - product_id_filter=plan.product_id - ), - "success_url": request.build_absolute_uri(reverse("success")), - "cancel_url": request.build_absolute_uri(reverse("cancel")), - } - if request.user.is_superuser: - cast["discounts"] = [{"coupon": settings.STRIPE_ADMIN_COUPON}] - session = stripe.checkout.Session.create(**cast) - await Session.objects.acreate(user=request.user, session=session.id) - return redirect(session.url) - # return JsonResponse({'id': session.id}) - except Exception as e: - # Raise a server error - return JsonResponse({"error": str(e)}, status=500) - - -class Cancel(LoginRequiredMixin, View): - async def get(self, request, plan_name): - if not settings.STRIPE_ENABLED: - return redirect(reverse("home")) - plan = Plan.objects.get(name=plan_name) - try: - subscriptions = stripe.Subscription.list( - customer=request.user.stripe_id, price=plan.product_id - ) - for subscription in subscriptions["data"]: - items = subscription["items"]["data"] - for item in items: - stripe.Subscription.delete(item["subscription"]) - return render(request, "subscriptioncancel.html", {"plan": plan}) - # return JsonResponse({'id': session.id}) - except Exception as e: - # Raise a server error - logging.error(f"Error cancelling subscription for user: {e}") - return JsonResponse({"error": "True"}, status=500) +# class Billing(LoginRequiredMixin, View): +# template_name = "billing.html" + +# async def get(self, request): +# if not settings.STRIPE_ENABLED: +# return redirect(reverse("home")) +# plans = await sync_to_async(list)(Plan.objects.all()) +# user_plans = await sync_to_async(list)(request.user.plans.all()) +# context = {"plans": plans, "user_plans": user_plans} +# return render(request, self.template_name, context) + + +# class Order(LoginRequiredMixin, View): +# async def get(self, request, plan_name): +# if not settings.STRIPE_ENABLED: +# return redirect(reverse("home")) +# plan = Plan.objects.get(name=plan_name) +# try: +# cast = { +# "payment_method_types": settings.ALLOWED_PAYMENT_METHODS, +# "mode": "subscription", +# "customer": request.user.stripe_id, +# "line_items": await assemble_plan_map( +# product_id_filter=plan.product_id +# ), +# "success_url": request.build_absolute_uri(reverse("success")), +# "cancel_url": request.build_absolute_uri(reverse("cancel")), +# } +# if request.user.is_superuser: +# cast["discounts"] = [{"coupon": settings.STRIPE_ADMIN_COUPON}] +# session = stripe.checkout.Session.create(**cast) +# await Session.objects.acreate(user=request.user, session=session.id) +# return redirect(session.url) +# # return JsonResponse({'id': session.id}) +# except Exception as e: +# # Raise a server error +# return JsonResponse({"error": str(e)}, status=500) + + +# class Cancel(LoginRequiredMixin, View): +# async def get(self, request, plan_name): +# if not settings.STRIPE_ENABLED: +# return redirect(reverse("home")) +# plan = Plan.objects.get(name=plan_name) +# try: +# subscriptions = stripe.Subscription.list( +# customer=request.user.stripe_id, price=plan.product_id +# ) +# for subscription in subscriptions["data"]: +# items = subscription["items"]["data"] +# for item in items: +# stripe.Subscription.delete(item["subscription"]) +# return render(request, "subscriptioncancel.html", {"plan": plan}) +# # return JsonResponse({'id': session.id}) +# except Exception as e: +# # Raise a server error +# logging.error(f"Error cancelling subscription for user: {e}") +# return JsonResponse({"error": "True"}, status=500) class Signup(CreateView): @@ -108,12 +110,12 @@ class Signup(CreateView): return super().get(request, *args, **kwargs) -class Portal(LoginRequiredMixin, View): - async def get(self, request): - if not settings.STRIPE_ENABLED: - return redirect(reverse("home")) - session = stripe.billing_portal.Session.create( - customer=request.user.stripe_id, - return_url=request.build_absolute_uri(reverse("billing")), - ) - return redirect(session.url) +# class Portal(LoginRequiredMixin, View): +# async def get(self, request): +# if not settings.STRIPE_ENABLED: +# return redirect(reverse("home")) +# session = stripe.billing_portal.Session.create( +# customer=request.user.stripe_id, +# return_url=request.build_absolute_uri(reverse("billing")), +# ) +# return redirect(session.url) diff --git a/core/views/stripe_callbacks.py b/core/views/stripe_callbacks.py index f193561..9e6b262 100644 --- a/core/views/stripe_callbacks.py +++ b/core/views/stripe_callbacks.py @@ -1,104 +1,104 @@ -import logging -from datetime import datetime +# import logging +# from datetime import datetime -import stripe -from django.conf import settings -from django.http import HttpResponse, JsonResponse -from django.views.decorators.csrf import csrf_exempt -from rest_framework.parsers import JSONParser -from rest_framework.views import APIView +# import stripe +# from django.conf import settings +# from django.http import HttpResponse, JsonResponse +# from django.views.decorators.csrf import csrf_exempt +# from rest_framework.parsers import JSONParser +# from rest_framework.views import APIView -from core.models import Plan, Session, User +# from core.models import Plan, Session, User -logger = logging.getLogger(__name__) +# logger = logging.getLogger(__name__) -class Callback(APIView): - parser_classes = [JSONParser] +# class Callback(APIView): +# parser_classes = [JSONParser] - # TODO: make async - @csrf_exempt - def post(self, request): - payload = request.body - sig_header = request.META["HTTP_STRIPE_SIGNATURE"] - try: - stripe.Webhook.construct_event( - payload, sig_header, settings.STRIPE_ENDPOINT_SECRET - ) - except ValueError: - # Invalid payload - logger.error("Invalid payload") - return HttpResponse(status=400) - except stripe.error.SignatureVerificationError: - # Invalid signature - logger.error("Invalid signature") - return HttpResponse(status=400) +# # TODO: make async +# @csrf_exempt +# def post(self, request): +# payload = request.body +# sig_header = request.META["HTTP_STRIPE_SIGNATURE"] +# try: +# stripe.Webhook.construct_event( +# payload, sig_header, settings.STRIPE_ENDPOINT_SECRET +# ) +# except ValueError: +# # Invalid payload +# logger.error("Invalid payload") +# return HttpResponse(status=400) +# except stripe.error.SignatureVerificationError: +# # Invalid signature +# logger.error("Invalid signature") +# return HttpResponse(status=400) - if request.data is None: - return JsonResponse({"success": False}, status=500) - if "type" in request.data.keys(): - rtype = request.data["type"] - if rtype == "checkout.session.completed": - session = request.data["data"]["object"]["id"] - subscription_id = request.data["data"]["object"]["subscription"] - session_map = Session.objects.get(session=session) - if not session_map: - return JsonResponse({"success": False}, status=500) - user = session_map.user - session_map.subscription_id = subscription_id - session_map.save() +# if request.data is None: +# return JsonResponse({"success": False}, status=500) +# if "type" in request.data.keys(): +# rtype = request.data["type"] +# if rtype == "checkout.session.completed": +# session = request.data["data"]["object"]["id"] +# subscription_id = request.data["data"]["object"]["subscription"] +# session_map = Session.objects.get(session=session) +# if not session_map: +# return JsonResponse({"success": False}, status=500) +# user = session_map.user +# session_map.subscription_id = subscription_id +# session_map.save() - if rtype == "customer.subscription.updated": - stripe_id = request.data["data"]["object"]["customer"] - if not stripe_id: - logging.error("No stripe id") - return JsonResponse({"success": False}, status=500) - user = User.objects.get(stripe_id=stripe_id) - # ssubscription_active - subscription_id = request.data["data"]["object"]["id"] - sessions = Session.objects.filter(user=user) - session = None - for session_iter in sessions: - if session_iter.subscription_id == subscription_id: - session = session_iter - if not session: - logging.error( - f"No session found for subscription id {subscription_id}" - ) - return JsonResponse({"success": False}, status=500) - # query Session objects - # iterate and check against product_id - session.request = request.data["request"]["id"] - product_id = request.data["data"]["object"]["plan"]["id"] - plan = Plan.objects.get(product_id=product_id) - if not plan: - logging.error(f"Plan not found: {product_id}") - return JsonResponse({"success": False}, status=500) - session.plan = plan - session.save() +# if rtype == "customer.subscription.updated": +# stripe_id = request.data["data"]["object"]["customer"] +# if not stripe_id: +# logging.error("No stripe id") +# return JsonResponse({"success": False}, status=500) +# user = User.objects.get(stripe_id=stripe_id) +# # ssubscription_active +# subscription_id = request.data["data"]["object"]["id"] +# sessions = Session.objects.filter(user=user) +# session = None +# for session_iter in sessions: +# if session_iter.subscription_id == subscription_id: +# session = session_iter +# if not session: +# logging.error( +# f"No session found for subscription id {subscription_id}" +# ) +# return JsonResponse({"success": False}, status=500) +# # query Session objects +# # iterate and check against product_id +# session.request = request.data["request"]["id"] +# product_id = request.data["data"]["object"]["plan"]["id"] +# plan = Plan.objects.get(product_id=product_id) +# if not plan: +# logging.error(f"Plan not found: {product_id}") +# return JsonResponse({"success": False}, status=500) +# session.plan = plan +# session.save() - elif rtype == "payment_intent.succeeded": - customer = request.data["data"]["object"]["customer"] - user = User.objects.get(stripe_id=customer) - if not user: - logging.error(f"No user found for customer: {customer}") - return JsonResponse({"success": False}, status=500) - session = Session.objects.get(request=request.data["request"]["id"]) +# elif rtype == "payment_intent.succeeded": +# customer = request.data["data"]["object"]["customer"] +# user = User.objects.get(stripe_id=customer) +# if not user: +# logging.error(f"No user found for customer: {customer}") +# return JsonResponse({"success": False}, status=500) +# session = Session.objects.get(request=request.data["request"]["id"]) - user.plans.add(session.plan) - user.last_payment = datetime.utcnow() - user.save() +# user.plans.add(session.plan) +# user.last_payment = datetime.utcnow() +# user.save() - elif rtype == "customer.subscription.deleted": - customer = request.data["data"]["object"]["customer"] - user = User.objects.get(stripe_id=customer) - if not user: - logging.error(f"No user found for customer {customer}") - return JsonResponse({"success": False}, status=500) - product_id = request.data["data"]["object"]["plan"]["id"] - plan = Plan.objects.get(product_id=product_id) - user.plans.remove(plan) - user.save() - else: - return JsonResponse({"success": False}, status=500) - return JsonResponse({"success": True}) +# elif rtype == "customer.subscription.deleted": +# customer = request.data["data"]["object"]["customer"] +# user = User.objects.get(stripe_id=customer) +# if not user: +# logging.error(f"No user found for customer {customer}") +# return JsonResponse({"success": False}, status=500) +# product_id = request.data["data"]["object"]["plan"]["id"] +# plan = Plan.objects.get(product_id=product_id) +# user.plans.remove(plan) +# user.save() +# else: +# return JsonResponse({"success": False}, status=500) +# return JsonResponse({"success": True})