Remove async and update CRUD helpers

master^2
Mark Veidemanis 1 year ago
parent c54b3e5412
commit 84be2a7278
Signed by: m
GPG Key ID: 5ACFCEED46C0904F

@ -39,4 +39,4 @@ if DEBUG:
"10.0.2.2", "10.0.2.2",
] ]
SETTINGS_EXPORT = ["STRIPE_ENABLED"] SETTINGS_EXPORT = ["STRIPE_ENABLED"]

@ -1,14 +1,12 @@
from asgiref.sync import sync_to_async
from core.models import Plan from core.models import Plan
async def assemble_plan_map(product_id_filter=None): def assemble_plan_map(product_id_filter=None):
""" """
Get all the plans from the database and create an object Stripe wants. Get all the plans from the database and create an object Stripe wants.
""" """
line_items = [] line_items = []
for plan in await sync_to_async(list)(Plan.objects.all()): for plan in Plan.objects.all():
if product_id_filter: if product_id_filter:
if plan.product_id != product_id_filter: if plan.product_id != product_id_filter:
continue continue

@ -1,5 +1,8 @@
import uuid import uuid
from django.core.exceptions import ImproperlyConfigured
from django.core.paginator import Paginator
from django.db.models import QuerySet
from django.http import Http404, HttpResponseBadRequest from django.http import Http404, HttpResponseBadRequest
from django.urls import reverse from django.urls import reverse
from django.views.generic.detail import DetailView from django.views.generic.detail import DetailView
@ -12,20 +15,81 @@ from core.util import logs
log = logs.get_logger(__name__) log = logs.get_logger(__name__)
class ObjectList(ListView): class RestrictedViewMixin:
"""
This mixin overrides two helpers in order to pass the user object to the filters.
get_queryset alters the objects returned for list views.
get_form_kwargs passes the request object to the form class. Remaining permissions
checks are in forms.py
"""
allow_empty = True
queryset = None
model = None
paginate_by = None
paginate_orphans = 0
context_object_name = None
paginator_class = Paginator
page_kwarg = "page"
ordering = None
def get_queryset(self):
"""
This function is overriden to filter the objects by the requesting user.
"""
if self.queryset is not None:
queryset = self.queryset
if isinstance(queryset, QuerySet):
# queryset = queryset.all()
queryset = queryset.filter(user=self.request.user)
elif self.model is not None:
queryset = self.model._default_manager.filter(user=self.request.user)
else:
raise ImproperlyConfigured(
"%(cls)s is missing a QuerySet. Define "
"%(cls)s.model, %(cls)s.queryset, or override "
"%(cls)s.get_queryset()." % {"cls": self.__class__.__name__}
)
if hasattr(self, "get_ordering"):
ordering = self.get_ordering()
if ordering:
if isinstance(ordering, str):
ordering = (ordering,)
queryset = queryset.order_by(*ordering)
return queryset
def get_form_kwargs(self):
"""Passes the request object to the form class.
This is necessary to only display members that belong to a given user"""
kwargs = super().get_form_kwargs()
kwargs["request"] = self.request
return kwargs
class ObjectNameMixin(object):
def __init__(self, *args, **kwargs):
self.title_singular = self.model._meta.verbose_name.title() # Hook
self.context_object_name_singular = self.title_singular.lower() # hook
self.title = self.model._meta.verbose_name_plural.title() # Hooks
self.context_object_name = self.title.lower() # hooks
self.context_object_name = self.context_object_name.replace(" ", "")
self.context_object_name_singular = self.context_object_name_singular.replace(
" ", ""
)
super().__init__(*args, **kwargs)
class ObjectList(RestrictedViewMixin, ObjectNameMixin, ListView):
allowed_types = ["modal", "widget", "window", "page"] allowed_types = ["modal", "widget", "window", "page"]
window_content = "window-content/objects.html" window_content = "window-content/objects.html"
list_template = None list_template = None
model = None
context_object_name = "objects"
context_object_name_singular = "object"
page_title = None page_title = None
page_subtitle = None page_subtitle = None
title = "Objects"
title_singular = "Object"
list_url_name = None list_url_name = None
# WARNING: TAKEN FROM locals() # WARNING: TAKEN FROM locals()
list_url_args = ["type"] list_url_args = ["type"]
@ -36,6 +100,7 @@ class ObjectList(ListView):
# copied from BaseListView # copied from BaseListView
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
self.request = request
self.object_list = self.get_queryset() self.object_list = self.get_queryset()
allow_empty = self.get_allow_empty() allow_empty = self.get_allow_empty()
@ -51,6 +116,7 @@ class ObjectList(ListView):
for arg in self.list_url_args: for arg in self.list_url_args:
list_url_args[arg] = locals()[arg] list_url_args[arg] = locals()[arg]
orig_type = type
if type == "page": if type == "page":
type = "modal" type = "modal"
@ -87,17 +153,19 @@ class ObjectList(ListView):
# Return partials for HTMX # Return partials for HTMX
if self.request.htmx: if self.request.htmx:
self.template_name = self.list_template if orig_type == "page":
self.template_name = self.list_template
else:
context["window_content"] = self.list_template
return self.render_to_response(context) return self.render_to_response(context)
class ObjectCreate(CreateView): class ObjectCreate(RestrictedViewMixin, ObjectNameMixin, CreateView):
allowed_types = ["modal", "widget", "window", "page"] allowed_types = ["modal", "widget", "window", "page"]
window_content = "window-content/object-form.html" window_content = "window-content/object-form.html"
parser_classes = [FormParser] parser_classes = [FormParser]
model = None model = None
context_object_name = "objects"
submit_url_name = None submit_url_name = None
list_url_name = None list_url_name = None
@ -122,8 +190,13 @@ class ObjectCreate(CreateView):
response["HX-Trigger"] = f"{self.context_object_name_singular}Event" response["HX-Trigger"] = f"{self.context_object_name_singular}Event"
return response return response
def form_invalid(self, form):
"""If the form is invalid, render the invalid form."""
return self.get(self.request, **self.kwargs, form=form)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
self.request = request self.request = request
self.kwargs = kwargs
type = kwargs.get("type", None) type = kwargs.get("type", None)
if not type: if not type:
return HttpResponseBadRequest("No type specified") return HttpResponseBadRequest("No type specified")
@ -144,6 +217,9 @@ class ObjectCreate(CreateView):
list_url = reverse(self.list_url_name, kwargs=list_url_args) list_url = reverse(self.list_url_name, kwargs=list_url_args)
context = self.get_context_data() context = self.get_context_data()
form = kwargs.get("form", None)
if form:
context["form"] = form
context["unique"] = unique context["unique"] = unique
context["window_content"] = self.window_content context["window_content"] = self.window_content
context["context_object_name"] = self.context_object_name context["context_object_name"] = self.context_object_name
@ -151,7 +227,9 @@ class ObjectCreate(CreateView):
context["submit_url"] = submit_url context["submit_url"] = submit_url
context["list_url"] = list_url context["list_url"] = list_url
context["type"] = type context["type"] = type
return self.render_to_response(context) response = self.render_to_response(context)
# response["HX-Trigger"] = f"{self.context_object_name_singular}Event"
return response
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
self.request = request self.request = request
@ -159,21 +237,19 @@ class ObjectCreate(CreateView):
return super().post(request, *args, **kwargs) return super().post(request, *args, **kwargs)
class ObjectRead(DetailView): class ObjectRead(RestrictedViewMixin, ObjectNameMixin, DetailView):
allowed_types = ["modal", "widget", "window", "page"] allowed_types = ["modal", "widget", "window", "page"]
window_content = "window-content/object.html" window_content = "window-content/object.html"
model = None model = None
context_object_name = "object"
class ObjectUpdate(UpdateView): class ObjectUpdate(RestrictedViewMixin, ObjectNameMixin, UpdateView):
allowed_types = ["modal", "widget", "window", "page"] allowed_types = ["modal", "widget", "window", "page"]
window_content = "window-content/object-form.html" window_content = "window-content/object-form.html"
parser_classes = [FormParser] parser_classes = [FormParser]
model = None model = None
context_object_name = "objects"
submit_url_name = None submit_url_name = None
request = None request = None
@ -193,6 +269,10 @@ class ObjectUpdate(UpdateView):
response["HX-Trigger"] = f"{self.context_object_name_singular}Event" response["HX-Trigger"] = f"{self.context_object_name_singular}Event"
return response return response
def form_invalid(self, form):
"""If the form is invalid, render the invalid form."""
return self.get(self.request, **self.kwargs, form=form)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
self.request = request self.request = request
type = kwargs.get("type", None) type = kwargs.get("type", None)
@ -211,13 +291,18 @@ class ObjectUpdate(UpdateView):
self.object = self.get_object() self.object = self.get_object()
submit_url = reverse(self.submit_url_name, kwargs={"type": type, "pk": pk}) submit_url = reverse(self.submit_url_name, kwargs={"type": type, "pk": pk})
context = self.get_context_data() context = self.get_context_data()
form = kwargs.get("form", None)
if form:
context["form"] = form
context["unique"] = unique context["unique"] = unique
context["window_content"] = self.window_content context["window_content"] = self.window_content
context["context_object_name"] = self.context_object_name context["context_object_name"] = self.context_object_name
context["context_object_name_singular"] = self.context_object_name_singular context["context_object_name_singular"] = self.context_object_name_singular
context["submit_url"] = submit_url context["submit_url"] = submit_url
context["type"] = type context["type"] = type
return self.render_to_response(context) response = self.render_to_response(context)
# response["HX-Trigger"] = f"{self.context_object_name_singular}Event"
return response
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
self.request = request self.request = request
@ -225,9 +310,8 @@ class ObjectUpdate(UpdateView):
return super().post(request, *args, **kwargs) return super().post(request, *args, **kwargs)
class ObjectDelete(DeleteView): class ObjectDelete(RestrictedViewMixin, ObjectNameMixin, DeleteView):
model = None model = None
context_object_name_singular = "object"
template_name = "partials/notify.html" template_name = "partials/notify.html"
# Overriden to prevent success URL from being used # Overriden to prevent success URL from being used

@ -1,7 +1,6 @@
import logging import logging
import stripe import stripe
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import JsonResponse from django.http import JsonResponse
@ -22,24 +21,24 @@ logger = logging.getLogger(__name__)
class Home(View): class Home(View):
template_name = "index.html" template_name = "index.html"
async def get(self, request): def get(self, request):
return render(request, self.template_name) return render(request, self.template_name)
class Billing(LoginRequiredMixin, View): class Billing(LoginRequiredMixin, View):
template_name = "billing.html" template_name = "billing.html"
async def get(self, request): def get(self, request):
if not settings.STRIPE_ENABLED: if not settings.STRIPE_ENABLED:
return redirect(reverse("home")) return redirect(reverse("home"))
plans = await sync_to_async(list)(Plan.objects.all()) plans = Plan.objects.all()
user_plans = await sync_to_async(list)(request.user.plans.all()) user_plans = request.user.plans.all()
context = {"plans": plans, "user_plans": user_plans} context = {"plans": plans, "user_plans": user_plans}
return render(request, self.template_name, context) return render(request, self.template_name, context)
class Order(LoginRequiredMixin, View): class Order(LoginRequiredMixin, View):
async def get(self, request, plan_name): def get(self, request, plan_name):
if not settings.STRIPE_ENABLED: if not settings.STRIPE_ENABLED:
return redirect(reverse("home")) return redirect(reverse("home"))
plan = Plan.objects.get(name=plan_name) plan = Plan.objects.get(name=plan_name)
@ -48,16 +47,14 @@ class Order(LoginRequiredMixin, View):
"payment_method_types": settings.ALLOWED_PAYMENT_METHODS, "payment_method_types": settings.ALLOWED_PAYMENT_METHODS,
"mode": "subscription", "mode": "subscription",
"customer": request.user.stripe_id, "customer": request.user.stripe_id,
"line_items": await assemble_plan_map( "line_items": assemble_plan_map(product_id_filter=plan.product_id),
product_id_filter=plan.product_id
),
"success_url": request.build_absolute_uri(reverse("success")), "success_url": request.build_absolute_uri(reverse("success")),
"cancel_url": request.build_absolute_uri(reverse("cancel")), "cancel_url": request.build_absolute_uri(reverse("cancel")),
} }
if request.user.is_superuser: if request.user.is_superuser:
cast["discounts"] = [{"coupon": settings.STRIPE_ADMIN_COUPON}] cast["discounts"] = [{"coupon": settings.STRIPE_ADMIN_COUPON}]
session = stripe.checkout.Session.create(**cast) session = stripe.checkout.Session.create(**cast)
await Session.objects.acreate(user=request.user, session=session.id) Session.objects.create(user=request.user, session=session.id)
return redirect(session.url) return redirect(session.url)
# return JsonResponse({'id': session.id}) # return JsonResponse({'id': session.id})
except Exception as e: except Exception as e:
@ -66,7 +63,7 @@ class Order(LoginRequiredMixin, View):
class Cancel(LoginRequiredMixin, View): class Cancel(LoginRequiredMixin, View):
async def get(self, request, plan_name): def get(self, request, plan_name):
if not settings.STRIPE_ENABLED: if not settings.STRIPE_ENABLED:
return redirect(reverse("home")) return redirect(reverse("home"))
plan = Plan.objects.get(name=plan_name) plan = Plan.objects.get(name=plan_name)
@ -98,7 +95,7 @@ class Signup(CreateView):
class Portal(LoginRequiredMixin, View): class Portal(LoginRequiredMixin, View):
async def get(self, request): def get(self, request):
if not settings.STRIPE_ENABLED: if not settings.STRIPE_ENABLED:
return redirect(reverse("home")) return redirect(reverse("home"))
session = stripe.billing_portal.Session.create( session = stripe.billing_portal.Session.create(

@ -16,7 +16,6 @@ logger = logging.getLogger(__name__)
class Callback(APIView): class Callback(APIView):
parser_classes = [JSONParser] parser_classes = [JSONParser]
# TODO: make async
@csrf_exempt @csrf_exempt
def post(self, request): def post(self, request):
payload = request.body payload = request.body

@ -7,14 +7,14 @@ from django.views import View
class DemoModal(View): class DemoModal(View):
template_name = "modals/modal.html" template_name = "modals/modal.html"
async def get(self, request): def get(self, request):
return render(request, self.template_name) return render(request, self.template_name)
class DemoWidget(View): class DemoWidget(View):
template_name = "widgets/widget.html" template_name = "widgets/widget.html"
async def get(self, request): def get(self, request):
unique = str(uuid.uuid4())[:8] unique = str(uuid.uuid4())[:8]
return render(request, self.template_name, {"unique": unique}) return render(request, self.template_name, {"unique": unique})
@ -22,5 +22,5 @@ class DemoWidget(View):
class DemoWindow(View): class DemoWindow(View):
template_name = "windows/window.html" template_name = "windows/window.html"
async def get(self, request): def get(self, request):
return render(request, self.template_name) return render(request, self.template_name)

Loading…
Cancel
Save