From eb33ae45e7f6babdac7585511a668ca8d7868274 Mon Sep 17 00:00:00 2001
From: Immanuel Onyeka <immanuel@debian-BULLSEYE-live-builder-AMD64>
Date: Tue, 6 Feb 2024 20:49:22 -0500
Subject: [PATCH] Add subscription update endpoints to router

---
 skouter.go | 124 ++++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 100 insertions(+), 24 deletions(-)

diff --git a/skouter.go b/skouter.go
index 4bf8283..86bf325 100644
--- a/skouter.go
+++ b/skouter.go
@@ -31,8 +31,8 @@ import (
 	"github.com/stripe/stripe-go/v76"
 	"github.com/stripe/stripe-go/v76/customer"
 	"github.com/stripe/stripe-go/v76/subscription"
-	"github.com/stripe/stripe-go/v76/invoice"
-	"github.com/stripe/stripe-go/v76/paymentintent"
+	// "github.com/stripe/stripe-go/v76/invoice"
+	// "github.com/stripe/stripe-go/v76/paymentintent"
 	"github.com/stripe/stripe-go/v76/webhook"
 	"image"
 	_ "image/jpeg"
@@ -1529,29 +1529,25 @@ func (sub *Subscription) updateSub(db *sql.DB) error {
 	var err error
 
 	query = `UPDATE subscription
-	SET client_secret = ?, payment_status = ?,
+	SET client_secret = CASE @a := ? WHEN '' THEN client_secret ELSE @a END,
+		current_period_end = CASE
+		@b := ? WHEN 0 THEN current_period_end ELSE @b END,
+		current_period_start = CASE
+		@c := ? WHEN 0 THEN current_period_start ELSE @c END,
+		payment_status = CASE @d := ? WHEN '' THEN client_secret ELSE @d END,
+		status = CASE @e := ? WHEN '' THEN client_secret ELSE @e END
 	WHERE id = ?
 	`
 	
-	s, err := subscription.Get(sub.StripeId, &stripe.SubscriptionParams{})
-	if err != nil { return err }
-	
-	i, err := invoice.Get(s.LatestInvoice.ID, &stripe.InvoiceParams{})
-	if err != nil { return err }
-	
-	p, err := paymentintent.Get(i.PaymentIntent.ID,
-	&stripe.PaymentIntentParams{})
-	if err != nil { return err }
-	
 	_, err = db.Exec(query,
-		p.ClientSecret,
-		p.Status,
+		sub.ClientSecret,
+		sub.End,
+		sub.Start,
+		sub.PaymentStatus,
+		sub.Status,
 		sub.Id,
 	)
 	if err != nil { return err }
-	
-	sub.ClientSecret = p.ClientSecret
-	sub.PaymentStatus = string(p.Status)
 
 	return err
 }
@@ -3167,9 +3163,11 @@ func invoiceFailed(w http.ResponseWriter, db *sql.DB, r *http.Request) {
 	log.Println(event.Data)
 }
 
-// A successful subscription payment should be confirmed by Stripe and
-// Updated through this hook.
-func subCreated(w http.ResponseWriter, db *sql.DB, r *http.Request) {
+// Important for catching subscription creation through Stripe dashboard
+// although it already happens at subscribe(). It checks if the user already
+// has a subscription and replaces those fields if necessary so a seperate
+// subCreated() is not necessary.
+func subUpdated(w http.ResponseWriter, db *sql.DB, r *http.Request) {
 
 	var sub stripe.Subscription
 	b, err := io.ReadAll(r.Body)
@@ -3181,7 +3179,7 @@ func subCreated(w http.ResponseWriter, db *sql.DB, r *http.Request) {
 	
 	event, err := webhook.ConstructEvent(b,
 	r.Header.Get("Stripe-Signature"),
-	hookKeys.SubCreated)
+	hookKeys.SubUpdated)
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusBadRequest)
 		log.Printf("webhook.ConstructEvent: %v", err)
@@ -3224,10 +3222,79 @@ func subCreated(w http.ResponseWriter, db *sql.DB, r *http.Request) {
 	user.Sub.PaymentStatus = string(sub.LatestInvoice.PaymentIntent.Status)
 	user.Sub.Status = string(sub.Status)
 	
-	if user.Sub.Id > 0 {
-		user.Sub.updateSub(db)
+	if user.Sub.Id != 0 {
+		err = user.Sub.insertSub(db)
 	} else {
+		user.Sub.updateSub(db)
+	}
+	
+	if err != nil {
+		http.Error(w, err.Error(), 500)
+		return
+	}
+	
+	log.Println("User subscription created:", user.Id, sub.ID)
+}
+
+// Handles changes to customer subscriptions sent by Stripe
+func subDeleted(w http.ResponseWriter, db *sql.DB, r *http.Request) {
+	var sub stripe.Subscription
+	b, err := io.ReadAll(r.Body)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		log.Printf("io.ReadAll: %v", err)
+		return
+	}
+	
+	event, err := webhook.ConstructEvent(b,
+	r.Header.Get("Stripe-Signature"),
+	hookKeys.SubUpdated)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		log.Printf("webhook.ConstructEvent: %v", err)
+		return
+	}
+	
+	// OK should be sent before any processing to confirm with Stripe that
+	// the hook was received
+	w.WriteHeader(http.StatusOK)
+	if event.Type != "customer.subscription.updated" {
+		log.Println(
+		"Invalid event type sent to customer.subscription.updated.")
+		return
+	}
+	
+	json.Unmarshal(event.Data.Raw, &sub)
+	log.Println(event.Type, sub.ID, sub.Customer.ID)
+	
+	user, err := queryCustomer(db, sub.Customer.ID)
+	if err != nil {
+		log.Printf("Could not query customer: %v", err)
+		return
+	}
+	
+	if statuses[user.Status] < 5 && sub.Status == "trialing" {
+		user.Status = "Trial"
+		user.update(db)
+	} else if sub.Status != "active" {
+		user.Status = "Unsubscribed"
+		user.update(db)
+	}
+	
+	user.Sub.UserId = user.Id
+	user.Sub.StripeId = sub.ID
+	user.Sub.CustomerId = user.CustomerId
+	user.Sub.PriceId = standardPriceId
+	user.Sub.End = int(sub.CurrentPeriodEnd)
+	user.Sub.Start = int(sub.CurrentPeriodStart)
+	user.Sub.ClientSecret = sub.LatestInvoice.PaymentIntent.ClientSecret
+	user.Sub.PaymentStatus = string(sub.LatestInvoice.PaymentIntent.Status)
+	user.Sub.Status = string(sub.Status)
+	
+	if user.Sub.Id != 0 {
 		err = user.Sub.insertSub(db)
+	} else {
+		user.Sub.updateSub(db)
 	}
 	
 	if err != nil {
@@ -3366,6 +3433,15 @@ func api(w http.ResponseWriter, r *http.Request) {
 	case match(p, "/api/stripe/invoice-payment-failed", &args) &&
 		r.Method == http.MethodPost:
 		invoiceFailed(w, db, r)
+	case match(p, "/api/stripe/sub-created", &args) &&
+		r.Method == http.MethodPost:
+		subUpdated(w, db, r)
+	case match(p, "/api/stripe/sub-updated", &args) &&
+		r.Method == http.MethodPost:
+		subUpdated(w, db, r)
+	case match(p, "/api/stripe/sub-updated", &args) &&
+		r.Method == http.MethodPost:
+		subUpdated(w, db, r)
 	default:
 		http.Error(w, "Invalid route or token", 404)
 	}