Compare commits

..

No commits in common. "ac4c248175cfbb3f91cb4a075d46e6f1749129a5" and "c6dd0ff2863f45edb722c13475382b4450f4aecb" have entirely different histories.

11 changed files with 268 additions and 273 deletions

View File

@ -13,7 +13,7 @@ ALLOWED_HOSTS = getenv("ALLOWED_HOSTS", f"127.0.0.1,{DOMAIN}").split(",")
CSRF_TRUSTED_ORIGINS = getenv("CSRF_TRUSTED_ORIGINS", URL).split(",") CSRF_TRUSTED_ORIGINS = getenv("CSRF_TRUSTED_ORIGINS", URL).split(",")
# Stripe # Stripe
BILLING_ENABLED = getenv("BILLING_ENABLED", "false").lower() in trues STRIPE_ENABLED = getenv("STRIPE_ENABLED", "false").lower() in trues
STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues STRIPE_TEST = getenv("STRIPE_TEST", "true").lower() in trues
STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "") STRIPE_API_KEY_TEST = getenv("STRIPE_API_KEY_TEST", "")
STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "") STRIPE_PUBLIC_API_KEY_TEST = getenv("STRIPE_PUBLIC_API_KEY_TEST", "")
@ -56,4 +56,4 @@ if DEBUG:
"10.0.2.2", "10.0.2.2",
] ]
SETTINGS_EXPORT = ["BILLING_ENABLED", "URL", "HOOK_PATH", "ASSET_PATH"] SETTINGS_EXPORT = ["STRIPE_ENABLED", "URL", "HOOK_PATH", "ASSET_PATH"]

View File

@ -38,25 +38,24 @@ from core.views import (
strategies, strategies,
trades, trades,
) )
from core.views.stripe_callbacks import Callback
# from core.views.stripe_callbacks import Callback
urlpatterns = [ urlpatterns = [
path("__debug__/", include("debug_toolbar.urls")), path("__debug__/", include("debug_toolbar.urls")),
path("", base.Home.as_view(), name="home"), path("", base.Home.as_view(), name="home"),
# path("callback", Callback.as_view(), name="callback"), path("callback", Callback.as_view(), name="callback"),
# path("billing/", base.Billing.as_view(), name="billing"), path("billing/", base.Billing.as_view(), name="billing"),
# path("order/<str:plan_name>/", base.Order.as_view(), name="order"), path("order/<str:plan_name>/", base.Order.as_view(), name="order"),
# path( path(
# "cancel_subscription/<str:plan_name>/", "cancel_subscription/<str:plan_name>/",
# base.Cancel.as_view(), base.Cancel.as_view(),
# name="cancel_subscription", name="cancel_subscription",
# ), ),
# path( path(
# "success/", TemplateView.as_view(template_name="success.html"), name="success" "success/", TemplateView.as_view(template_name="success.html"), name="success"
# ), ),
# path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"), path("cancel/", TemplateView.as_view(template_name="cancel.html"), name="cancel"),
# path("portal", base.Portal.as_view(), name="portal"), path("portal", base.Portal.as_view(), name="portal"),
path("sapp/", admin.site.urls), path("sapp/", admin.site.urls),
# 2FA login urls # 2FA login urls
path("", include(tf_urls)), path("", include(tf_urls)),

View File

@ -2,13 +2,15 @@ from django.contrib import admin
from django.contrib.auth.admin import UserAdmin from django.contrib.auth.admin import UserAdmin
from .forms import CustomUserCreationForm from .forms import CustomUserCreationForm
from .models import ( # AssetRestriction,; Plan,; Session, from .models import ( # AssetRestriction,
Account, Account,
AssetGroup, AssetGroup,
Callback, Callback,
Hook, Hook,
NotificationSettings, NotificationSettings,
Plan,
RiskModel, RiskModel,
Session,
Signal, Signal,
Strategy, Strategy,
Trade, Trade,
@ -25,7 +27,7 @@ from .models import ( # AssetRestriction,; Plan,; Session,
# Register your models here. # Register your models here.
class CustomUserAdmin(UserAdmin): class CustomUserAdmin(UserAdmin):
# list_filter = ["plans"] list_filter = ["plans"]
model = User model = User
add_form = CustomUserCreationForm add_form = CustomUserCreationForm
fieldsets = ( fieldsets = (
@ -38,7 +40,7 @@ class CustomUserAdmin(UserAdmin):
"Payment information", "Payment information",
{ {
"fields": ( "fields": (
# "plans", "plans",
"last_payment", "last_payment",
) )
}, },
@ -99,8 +101,8 @@ class AssetGroupAdmin(admin.ModelAdmin):
admin.site.register(User, CustomUserAdmin) admin.site.register(User, CustomUserAdmin)
# admin.site.register(Plan) admin.site.register(Plan)
# admin.site.register(Session) admin.site.register(Session)
admin.site.register(Account, AccountAdmin) admin.site.register(Account, AccountAdmin)
admin.site.register(Hook, HookAdmin) admin.site.register(Hook, HookAdmin)

View File

@ -1,11 +1,4 @@
from django.conf import settings # from lago_python_client import Client
from lago_python_client import Client # from django.conf import settings
from lago_python_client.models import (
Charge,
Charges,
Customer,
CustomerBillingConfiguration,
Plan,
)
client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL) # client = Client(api_key=settings.LAGO_API_KEY, api_url=settings.LAGO_URL)

View File

@ -1,65 +1,65 @@
# import logging import logging
# import stripe import stripe
# logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# def expand_name(first_name, last_name): def expand_name(first_name, last_name):
# """ """
# Convert two name variables into one. Convert two name variables into one.
# Last name without a first name is ignored. Last name without a first name is ignored.
# """ """
# name = None name = None
# if first_name: if first_name:
# name = first_name name = first_name
# # We only want to put the last name if we have a first name # We only want to put the last name if we have a first name
# if last_name: if last_name:
# name += f" {last_name}" name += f" {last_name}"
# return name return name
# def get_or_create(email, first_name, last_name): def get_or_create(email, first_name, last_name):
# """ """
# Get a customer ID from Stripe if one with the given email exists. Get a customer ID from Stripe if one with the given email exists.
# Create a customer if one does not. Create a customer if one does not.
# Raise an exception if two or more customers matching the given email exist. Raise an exception if two or more customers matching the given email exist.
# """ """
# # Let's see if we're just missing the ID # Let's see if we're just missing the ID
# matching_customers = stripe.Customer.list(email=email, limit=2) matching_customers = stripe.Customer.list(email=email, limit=2)
# if len(matching_customers) == 2: if len(matching_customers) == 2:
# # Something is horribly wrong # Something is horribly wrong
# logger.error(f"Two customers found for email {email}") logger.error(f"Two customers found for email {email}")
# raise Exception(f"Two customers found for email {email}") raise Exception(f"Two customers found for email {email}")
# elif len(matching_customers) == 1: elif len(matching_customers) == 1:
# # We found a customer. Let's copy the ID # We found a customer. Let's copy the ID
# customer = matching_customers["data"][0] customer = matching_customers["data"][0]
# customer_id = customer["id"] customer_id = customer["id"]
# return customer_id return customer_id
# else: else:
# # We didn't find anything. Create the customer # We didn't find anything. Create the customer
# # Create a name, since we have 2 variables which could be null # Create a name, since we have 2 variables which could be null
# name = expand_name(first_name, last_name) name = expand_name(first_name, last_name)
# cast = {"email": email} cast = {"email": email}
# if name: if name:
# cast["name"] = name cast["name"] = name
# customer = stripe.Customer.create(**cast) customer = stripe.Customer.create(**cast)
# logger.info(f"Created new Stripe customer {customer.id} with email {email}") logger.info(f"Created new Stripe customer {customer.id} with email {email}")
# return customer.id return customer.id
# def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None): def update_customer_fields(stripe_id, email=None, first_name=None, last_name=None):
# """ """
# Update the customer fields in Stripe. Update the customer fields in Stripe.
# """ """
# if email: if email:
# stripe.Customer.modify(stripe_id, email=email) stripe.Customer.modify(stripe_id, email=email)
# logger.info(f"Modified Stripe customer {stripe_id} to have email {email}") logger.info(f"Modified Stripe customer {stripe_id} to have email {email}")
# if first_name or last_name: if first_name or last_name:
# name = expand_name(first_name, last_name) name = expand_name(first_name, last_name)
# stripe.Customer.modify(stripe_id, name=name) stripe.Customer.modify(stripe_id, name=name)
# logger.info(f"Modified Stripe customer {stripe_id} to have email {name}") logger.info(f"Modified Stripe customer {stripe_id} to have email {name}")

View File

@ -1,21 +1,21 @@
# from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
# from core.models import Plan from core.models import Plan
# async def assemble_plan_map(product_id_filter=None): async def assemble_plan_map(product_id_filter=None):
# """ """
# Get all the plans from the database and create an object Stripe wants. Get all the plans from the database and create an object Stripe wants.
# """ """
# line_items = [] line_items = []
# for plan in await sync_to_async(list)(Plan.objects.all()): for plan in await sync_to_async(list)(Plan.objects.all()):
# if product_id_filter: if product_id_filter:
# if plan.product_id != product_id_filter: if plan.product_id != product_id_filter:
# continue continue
# line_items.append( line_items.append(
# { {
# "price": plan.product_id, "price": plan.product_id,
# "quantity": 1, "quantity": 1,
# } }
# ) )
# return line_items return line_items

View File

@ -1,4 +1,5 @@
{% load static %} {% load static %}
{% load has_plan %}
{% load cache %} {% load cache %}
<!DOCTYPE html> <!DOCTYPE html>
@ -286,9 +287,9 @@
</div> </div>
</div> </div>
{% endif %} {% endif %}
{% if settings.BILLING_ENABLED %} {% if settings.STRIPE_ENABLED %}
{% if user.is_authenticated %} {% if user.is_authenticated %}
<a class="navbar-item" href="{# url 'billing' #}"> <a class="navbar-item" href="{% url 'billing' %}">
Billing Billing
</a> </a>
{% endif %} {% endif %}

View File

@ -0,0 +1,11 @@
from django import template
register = template.Library()
@register.filter
def has_plan(user, plan_name):
if not hasattr(user, "plans"):
return False
plan_list = [plan.name for plan in user.plans.all()]
return plan_name in plan_list

View File

@ -5,15 +5,6 @@ from django.test import TestCase
from core.models import TradingTime, User from core.models import TradingTime, User
class ModelTestCase(TestCase):
def setUp(self):
# Create a test user
self.user = User.objects.create_user(
username="testuser",
email="testuser@example.com",
)
class MarketTestCase(TestCase): class MarketTestCase(TestCase):
def setUp(self): def setUp(self):
# Create a test user # Create a test user

View File

@ -11,11 +11,9 @@ from django.views import View
from django.views.generic.edit import CreateView from django.views.generic.edit import CreateView
from core.forms import NewUserForm from core.forms import NewUserForm
# from core.lib.products import assemble_plan_map
# from core.models import Plan, Session
from core.lib import billing
from core.lib.notify import raw_sendmsg from core.lib.notify import raw_sendmsg
from core.lib.products import assemble_plan_map
from core.models import Plan, Session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,64 +27,64 @@ class Home(LoginRequiredMixin, View):
return render(request, self.template_name) return render(request, self.template_name)
# class Billing(LoginRequiredMixin, View): class Billing(LoginRequiredMixin, View):
# template_name = "billing.html" template_name = "billing.html"
# async def get(self, request): async def get(self, request):
# if not settings.STRIPE_ENABLED: if not settings.STRIPE_ENABLED:
# return redirect(reverse("home")) return redirect(reverse("home"))
# plans = await sync_to_async(list)(Plan.objects.all()) plans = await sync_to_async(list)(Plan.objects.all())
# user_plans = await sync_to_async(list)(request.user.plans.all()) user_plans = await sync_to_async(list)(request.user.plans.all())
# context = {"plans": plans, "user_plans": user_plans} context = {"plans": plans, "user_plans": user_plans}
# return render(request, self.template_name, context) return render(request, self.template_name, context)
# class Order(LoginRequiredMixin, View): class Order(LoginRequiredMixin, View):
# async def get(self, request, plan_name): async def get(self, request, plan_name):
# if not settings.STRIPE_ENABLED: if not settings.STRIPE_ENABLED:
# return redirect(reverse("home")) return redirect(reverse("home"))
# plan = Plan.objects.get(name=plan_name) plan = Plan.objects.get(name=plan_name)
# try: try:
# cast = { cast = {
# "payment_method_types": settings.ALLOWED_PAYMENT_METHODS, "payment_method_types": settings.ALLOWED_PAYMENT_METHODS,
# "mode": "subscription", "mode": "subscription",
# "customer": request.user.stripe_id, "customer": request.user.stripe_id,
# "line_items": await assemble_plan_map( "line_items": await assemble_plan_map(
# product_id_filter=plan.product_id product_id_filter=plan.product_id
# ), ),
# "success_url": request.build_absolute_uri(reverse("success")), "success_url": request.build_absolute_uri(reverse("success")),
# "cancel_url": request.build_absolute_uri(reverse("cancel")), "cancel_url": request.build_absolute_uri(reverse("cancel")),
# } }
# if request.user.is_superuser: if request.user.is_superuser:
# cast["discounts"] = [{"coupon": settings.STRIPE_ADMIN_COUPON}] cast["discounts"] = [{"coupon": settings.STRIPE_ADMIN_COUPON}]
# session = stripe.checkout.Session.create(**cast) session = stripe.checkout.Session.create(**cast)
# await Session.objects.acreate(user=request.user, session=session.id) await Session.objects.acreate(user=request.user, session=session.id)
# return redirect(session.url) return redirect(session.url)
# # return JsonResponse({'id': session.id}) # return JsonResponse({'id': session.id})
# except Exception as e: except Exception as e:
# # Raise a server error # Raise a server error
# return JsonResponse({"error": str(e)}, status=500) return JsonResponse({"error": str(e)}, status=500)
# class Cancel(LoginRequiredMixin, View): class Cancel(LoginRequiredMixin, View):
# async def get(self, request, plan_name): async def get(self, request, plan_name):
# if not settings.STRIPE_ENABLED: if not settings.STRIPE_ENABLED:
# return redirect(reverse("home")) return redirect(reverse("home"))
# plan = Plan.objects.get(name=plan_name) plan = Plan.objects.get(name=plan_name)
# try: try:
# subscriptions = stripe.Subscription.list( subscriptions = stripe.Subscription.list(
# customer=request.user.stripe_id, price=plan.product_id customer=request.user.stripe_id, price=plan.product_id
# ) )
# for subscription in subscriptions["data"]: for subscription in subscriptions["data"]:
# items = subscription["items"]["data"] items = subscription["items"]["data"]
# for item in items: for item in items:
# stripe.Subscription.delete(item["subscription"]) stripe.Subscription.delete(item["subscription"])
# return render(request, "subscriptioncancel.html", {"plan": plan}) return render(request, "subscriptioncancel.html", {"plan": plan})
# # return JsonResponse({'id': session.id}) # return JsonResponse({'id': session.id})
# except Exception as e: except Exception as e:
# # Raise a server error # Raise a server error
# logging.error(f"Error cancelling subscription for user: {e}") logging.error(f"Error cancelling subscription for user: {e}")
# return JsonResponse({"error": "True"}, status=500) return JsonResponse({"error": "True"}, status=500)
class Signup(CreateView): class Signup(CreateView):
@ -110,12 +108,12 @@ class Signup(CreateView):
return super().get(request, *args, **kwargs) return super().get(request, *args, **kwargs)
# class Portal(LoginRequiredMixin, View): class Portal(LoginRequiredMixin, View):
# async def get(self, request): async def get(self, request):
# if not settings.STRIPE_ENABLED: if not settings.STRIPE_ENABLED:
# return redirect(reverse("home")) return redirect(reverse("home"))
# session = stripe.billing_portal.Session.create( session = stripe.billing_portal.Session.create(
# customer=request.user.stripe_id, customer=request.user.stripe_id,
# return_url=request.build_absolute_uri(reverse("billing")), return_url=request.build_absolute_uri(reverse("billing")),
# ) )
# return redirect(session.url) return redirect(session.url)

View File

@ -1,104 +1,104 @@
# import logging import logging
# from datetime import datetime from datetime import datetime
# import stripe import stripe
# from django.conf import settings from django.conf import settings
# from django.http import HttpResponse, JsonResponse from django.http import HttpResponse, JsonResponse
# from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
# from rest_framework.parsers import JSONParser from rest_framework.parsers import JSONParser
# from rest_framework.views import APIView from rest_framework.views import APIView
# from core.models import Plan, Session, User from core.models import Plan, Session, User
# logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# class Callback(APIView): class Callback(APIView):
# parser_classes = [JSONParser] parser_classes = [JSONParser]
# # TODO: make async # TODO: make async
# @csrf_exempt @csrf_exempt
# def post(self, request): def post(self, request):
# payload = request.body payload = request.body
# sig_header = request.META["HTTP_STRIPE_SIGNATURE"] sig_header = request.META["HTTP_STRIPE_SIGNATURE"]
# try: try:
# stripe.Webhook.construct_event( stripe.Webhook.construct_event(
# payload, sig_header, settings.STRIPE_ENDPOINT_SECRET payload, sig_header, settings.STRIPE_ENDPOINT_SECRET
# ) )
# except ValueError: except ValueError:
# # Invalid payload # Invalid payload
# logger.error("Invalid payload") logger.error("Invalid payload")
# return HttpResponse(status=400) return HttpResponse(status=400)
# except stripe.error.SignatureVerificationError: except stripe.error.SignatureVerificationError:
# # Invalid signature # Invalid signature
# logger.error("Invalid signature") logger.error("Invalid signature")
# return HttpResponse(status=400) return HttpResponse(status=400)
# if request.data is None: if request.data is None:
# return JsonResponse({"success": False}, status=500) return JsonResponse({"success": False}, status=500)
# if "type" in request.data.keys(): if "type" in request.data.keys():
# rtype = request.data["type"] rtype = request.data["type"]
# if rtype == "checkout.session.completed": if rtype == "checkout.session.completed":
# session = request.data["data"]["object"]["id"] session = request.data["data"]["object"]["id"]
# subscription_id = request.data["data"]["object"]["subscription"] subscription_id = request.data["data"]["object"]["subscription"]
# session_map = Session.objects.get(session=session) session_map = Session.objects.get(session=session)
# if not session_map: if not session_map:
# return JsonResponse({"success": False}, status=500) return JsonResponse({"success": False}, status=500)
# user = session_map.user user = session_map.user
# session_map.subscription_id = subscription_id session_map.subscription_id = subscription_id
# session_map.save() session_map.save()
# if rtype == "customer.subscription.updated": if rtype == "customer.subscription.updated":
# stripe_id = request.data["data"]["object"]["customer"] stripe_id = request.data["data"]["object"]["customer"]
# if not stripe_id: if not stripe_id:
# logging.error("No stripe id") logging.error("No stripe id")
# return JsonResponse({"success": False}, status=500) return JsonResponse({"success": False}, status=500)
# user = User.objects.get(stripe_id=stripe_id) user = User.objects.get(stripe_id=stripe_id)
# # ssubscription_active # ssubscription_active
# subscription_id = request.data["data"]["object"]["id"] subscription_id = request.data["data"]["object"]["id"]
# sessions = Session.objects.filter(user=user) sessions = Session.objects.filter(user=user)
# session = None session = None
# for session_iter in sessions: for session_iter in sessions:
# if session_iter.subscription_id == subscription_id: if session_iter.subscription_id == subscription_id:
# session = session_iter session = session_iter
# if not session: if not session:
# logging.error( logging.error(
# f"No session found for subscription id {subscription_id}" f"No session found for subscription id {subscription_id}"
# ) )
# return JsonResponse({"success": False}, status=500) return JsonResponse({"success": False}, status=500)
# # query Session objects # query Session objects
# # iterate and check against product_id # iterate and check against product_id
# session.request = request.data["request"]["id"] session.request = request.data["request"]["id"]
# product_id = request.data["data"]["object"]["plan"]["id"] product_id = request.data["data"]["object"]["plan"]["id"]
# plan = Plan.objects.get(product_id=product_id) plan = Plan.objects.get(product_id=product_id)
# if not plan: if not plan:
# logging.error(f"Plan not found: {product_id}") logging.error(f"Plan not found: {product_id}")
# return JsonResponse({"success": False}, status=500) return JsonResponse({"success": False}, status=500)
# session.plan = plan session.plan = plan
# session.save() session.save()
# elif rtype == "payment_intent.succeeded": elif rtype == "payment_intent.succeeded":
# customer = request.data["data"]["object"]["customer"] customer = request.data["data"]["object"]["customer"]
# user = User.objects.get(stripe_id=customer) user = User.objects.get(stripe_id=customer)
# if not user: if not user:
# logging.error(f"No user found for customer: {customer}") logging.error(f"No user found for customer: {customer}")
# return JsonResponse({"success": False}, status=500) return JsonResponse({"success": False}, status=500)
# session = Session.objects.get(request=request.data["request"]["id"]) session = Session.objects.get(request=request.data["request"]["id"])
# user.plans.add(session.plan) user.plans.add(session.plan)
# user.last_payment = datetime.utcnow() user.last_payment = datetime.utcnow()
# user.save() user.save()
# elif rtype == "customer.subscription.deleted": elif rtype == "customer.subscription.deleted":
# customer = request.data["data"]["object"]["customer"] customer = request.data["data"]["object"]["customer"]
# user = User.objects.get(stripe_id=customer) user = User.objects.get(stripe_id=customer)
# if not user: if not user:
# logging.error(f"No user found for customer {customer}") logging.error(f"No user found for customer {customer}")
# return JsonResponse({"success": False}, status=500) return JsonResponse({"success": False}, status=500)
# product_id = request.data["data"]["object"]["plan"]["id"] product_id = request.data["data"]["object"]["plan"]["id"]
# plan = Plan.objects.get(product_id=product_id) plan = Plan.objects.get(product_id=product_id)
# user.plans.remove(plan) user.plans.remove(plan)
# user.save() user.save()
# else: else:
# return JsonResponse({"success": False}, status=500) return JsonResponse({"success": False}, status=500)
# return JsonResponse({"success": True}) return JsonResponse({"success": True})