diff --git a/go.mod b/go.mod index 5c117e4..b2cc096 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,8 @@ module example.com/m go 1.19 -require github.com/go-sql-driver/mysql v1.6.0 +require ( + github.com/SebastiaanKlippert/go-wkhtmltopdf v1.9.0 + github.com/go-sql-driver/mysql v1.6.0 + github.com/golang-jwt/jwt/v4 v4.5.0 +) diff --git a/go.sum b/go.sum index 20c16d6..7bb8a6e 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,6 @@ +github.com/SebastiaanKlippert/go-wkhtmltopdf v1.9.0 h1:DNrExYwvyyI404SxdUCCANAj9TwnGjRfa3cYFMNY1AU= +github.com/SebastiaanKlippert/go-wkhtmltopdf v1.9.0/go.mod h1:SQq4xfIdvf6WYKSDxAJc+xOJdolt+/bc1jnQKMtPMvQ= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= diff --git a/migrations/0_29092022_create_main_tables.sql b/migrations/0_29092022_create_main_tables.sql index 65d0399..ea86d4a 100644 --- a/migrations/0_29092022_create_main_tables.sql +++ b/migrations/0_29092022_create_main_tables.sql @@ -108,12 +108,13 @@ CREATE TABLE mi ( five_year_total INT DEFAULT 0, initial_premium INT DEFAULT 0, initial_rate INT DEFAULT 0, + initial_amount INT DEFAULT 0, PRIMARY KEY (`id`), FOREIGN KEY (loan_id) REFERENCES loan(id) ); /* template = true fees are saved for users or branches. If template or default - * are true, estimate_id should be null.*/ + * are true, estimate_id should be null. */ CREATE TABLE fee ( id INT AUTO_INCREMENT NOT NULL, loan_id INT, diff --git a/skouter.go b/skouter.go index 8505cd6..6f35aef 100644 --- a/skouter.go +++ b/skouter.go @@ -10,11 +10,21 @@ import ( _ "github.com/go-sql-driver/mysql" "fmt" "encoding/json" - // "io" "strconv" "bytes" + "time" + "errors" + "strings" + pdf "github.com/SebastiaanKlippert/go-wkhtmltopdf" + "github.com/golang-jwt/jwt/v4" ) +type UserClaims struct { + Id int `json:"id"` + Role string `json:"role"` + Exp string `json:"exp"` +} + type Page struct { tpl *template.Template Title string @@ -75,16 +85,16 @@ type Loan struct { } type MI struct { - Type string - Label string - Lender string - Rate float32 - Premium float32 - Upfront float32 - FiveYearTotal float32 - InitialAllInPremium float32 - InitialAllInRate float32 - InitialAmount float32 + Type string + Label string + Lender string + Rate float32 + Premium float32 + Upfront float32 + FiveYearTotal float32 + InitialAllInPremium float32 + InitialAllInRate float32 + InitialAmount float32 } type Estimate struct { @@ -118,6 +128,22 @@ var pages = map[string]Page { "app": cache("app", "App"), } +var roles = map[string]int{ + "guest": 1, + "employee": 2, + "admin": 3, +} + +// Used to validate claim in JWT token body. Checks if user id is greater than +// zero and time format is valid +func (c UserClaims) Valid() error { + if c.Id < 1 { return errors.New("Invalid id") } + t, err := time.Parse(time.UnixDate, c.Exp) + if err != nil { return err } + if t.Before(time.Now()) { return errors.New("Token expired.") } + return err +} + func cache(name string, title string) Page { var p = []string{"master.tpl", paths[name]} tpl := template.Must(template.ParseFiles(p...)) @@ -482,6 +508,7 @@ func fetchMi(db *sql.DB, estimate *Estimate, pos int) []MI { resp, err := http.DefaultClient.Do(req) var res map[string]interface{} + var result []MI if resp.StatusCode != 200 { log.Printf("the status: %v\nthe resp: %v\n the req: %v\n the body: %v\n", @@ -489,40 +516,116 @@ func fetchMi(db *sql.DB, estimate *Estimate, pos int) []MI { } else { json.NewDecoder(resp.Body).Decode(&res) // estimate.Loans[pos].Mi = res + // Parse res into result here + } + + return result +} + +func login(w http.ResponseWriter, db *sql.DB, r *http.Request) { + var id int + var role string + var err error + r.ParseForm() + + row := db.QueryRow( + `SELECT id, role FROM user WHERE email = ? AND password = sha2(?, 256)`, + r.PostFormValue("email"), r.PostFormValue("password")) + + err = row.Scan(&id, &role) + if err != nil { + http.Error(w, "Invalid Credentials.", http.StatusUnauthorized) + return } - return estimate + token := jwt.NewWithClaims(jwt.SigningMethodHS256, + UserClaims{ Id: id, Role: role, + Exp: time.Now().Add(time.Minute * 30).Format(time.UnixDate)}) + + tokenStr, err := token.SignedString([]byte(config["JWT_SECRET"])) + if err != nil { + log.Println("Token could not be signed: ", err, tokenStr) + http.Error(w, "Token generation error.", http.StatusInternalServerError) + return + } + + cookie := http.Cookie{Name: "hound", + Value: tokenStr, + Path: "/", + Expires: time.Now().Add(time.Hour * 24)} + http.SetCookie(w, &cookie) + _, err = w.Write([]byte(tokenStr)) + if err != nil { + http.Error(w, + "Could not complete token write.", + http.StatusInternalServerError)} +} + +func getToken(w http.ResponseWriter, db *sql.DB, r *http.Request) { + claims, err := getClaims(r) + // Will verify existing signature and expiry time + token := jwt.NewWithClaims(jwt.SigningMethodHS256, + UserClaims{ Id: claims.Id, Role: claims.Role, + Exp: time.Now().Add(time.Minute * 30).Format(time.UnixDate)}) + + tokenStr, err := token.SignedString([]byte(config["JWT_SECRET"])) + if err != nil { + log.Println("Token could not be signed: ", err, tokenStr) + http.Error(w, "Token generation error.", http.StatusInternalServerError) + return + } + + cookie := http.Cookie{Name: "hound", + Value: tokenStr, + Path: "/", + Expires: time.Now().Add(time.Hour * 24)} + http.SetCookie(w, &cookie) + _, err = w.Write([]byte(tokenStr)) + if err != nil { + http.Error(w, + "Could not complete token write.", + http.StatusInternalServerError)} } func validateEstimate() { return } -func route(w http.ResponseWriter, r *http.Request) { - var page Page - var args []string - p := r.URL.Path +func getClaims(r *http.Request) (UserClaims, error) { + claims := new(UserClaims) + _, tokenStr, found := strings.Cut(r.Header.Get("Authorization"), " ") - switch { - case r.Method == "GET" && match(p, "/", &args): - page = pages[ "home" ] - case match(p, "/terms", &args): - page = pages[ "terms" ] - case match(p, "/app", &args): - page = pages[ "app" ] - case match(p, "/assets", &args): - page = pages[ "app" ] - default: - http.NotFound(w, r) - return - } + if !found { + return *claims, errors.New("Token not found") + } - page.Render(w) + // Pull token payload into UserClaims + _, err := jwt.ParseWithClaims(tokenStr, claims, + func(token *jwt.Token) (any, error) { + return []byte(config["JWT_SECRET"]), nil + }) + + if err != nil { + return *claims, err + } + + if err = claims.Valid(); err != nil { + return *claims, err + } + + return *claims, nil +} + +func guard(r *http.Request, required int) bool { + claims, err := getClaims(r) + if err != nil { return false } + if roles[claims.Role] < required { return false } + + return true } func api(w http.ResponseWriter, r *http.Request) { var args []string - // var response string p := r.URL.Path db, err := sql.Open("mysql", @@ -534,54 +637,114 @@ func api(w http.ResponseWriter, r *http.Request) { err = db.Ping() if err != nil { - print("Bad database configuration: %v", err) + fmt.Println("Bad database configuration: %v\n", err) panic(err) // maybe os.Exit(1) instead } - switch { - case match(p, "/api/loans", &args): - resp, err := getLoanType(db, 0, 0, true) - - if resp != nil { - json.NewEncoder(w).Encode(resp) - } else { - json.NewEncoder(w).Encode(err) - } - - case match(p, "/api/fees", &args): - resp, err := getFeesTemp(db, 0) - - if resp != nil { - json.NewEncoder(w).Encode(resp) - } else { - json.NewEncoder(w).Encode(err) - } + case match(p, "/api/login", &args) && + r.Method == http.MethodPost: + login(w, db, r) + case match(p, "/api/token", &args) && + r.Method == http.MethodGet && guard(r, 1): + getToken(w, db, r) + case match(p, "/api/users", &args) && // Array of all users + r.Method == http.MethodGet && guard(r, 2): + getUsers(w, db, r) + case match(p, "/api/user", &args) && + r.Method == http.MethodGet, && guard(r, 1): + getUser(w, db, r) + case match(p, "/api/user", &args) && + r.Method == http.MethodPost && + guard(r, 3): + createUser(w, db, r) + case match(p, "/api/user", &args) && + r.Method == http.MethodPatch && + guard(r, 3): // For admin to modify any user + patchUser(w, db, r) + case match(p, "/api/user", &args) && + r.Method == http.MethodPatch && + guard(r, 2): // For employees to modify own accounts + patchSelf(w, db, r) + case match(p, "/api/user", &args) && + r.Method == http.MethodDelete && + guard(r, 3): + deleteUser(w, db, r) + case match(p, "/api/batch", &args) && + r.Method == http.MethodGet && + guard(r, 1): + getBatch(w, db, r) + case match(p, "/api/batch", &args) && + r.Method == http.MethodPost && + guard(r, 2): + openBatch(w, db, r) + case match(p, "/api/batch", &args) && + r.Method == http.MethodPatch && + guard(r, 2): + closeBatch(w, db, r) + case match(p, "/api/client", &args) && + r.Method == http.MethodGet && + guard(r, 1): + getClient(w, db, r) + case match(p, "/api/client", &args) && + r.Method == http.MethodPost && + guard(r, 2): + createClient(w, db, r) + case match(p, "/api/client", &args) && + r.Method == http.MethodPatch && + guard(r, 2): + patchClient(w, db, r) + case match(p, "/api/client", &args) && + r.Method == http.MethodDelete && + guard(r, 2): + deleteClient(w, db, r) + case match(p, "/api/ticket", &args) && + r.Method == http.MethodPost && + guard(r, 2): + openTicket(w, db, r) + case match(p, "/api/ticket", &args) && + r.Method == http.MethodPatch && + guard(r, 2): + closeTicket(w, db, r) + case match(p, "/api/ticket", &args) && + r.Method == http.MethodDelete && + guard(r, 2): + voidTicket(w, db, r) + case match(p, "/api/report/batch", &args) && + r.Method == http.MethodGet && + guard(r, 2): + reportBatch(w, db, r) + case match(p, "/api/report/summary", &args) && + r.Method == http.MethodPost && + guard(r, 2): + reportSummary(w, db, r) + } - case match(p, "/api/mi", &args): - var err error - est, err := getEstimate(db, 1) - if err != nil { - json.NewEncoder(w).Encode(err) - log.Println("error occured:", err) - break - } + db.Close() +} - json.NewEncoder(w).Encode(fetchMi(db, &est, 0).Loans[0].Mi) +func route(w http.ResponseWriter, r *http.Request) { + var page Page + var args []string + p := r.URL.Path - // if err != nil { - // json.NewEncoder(w).Encode(err) - // break - // } else { - // json.NewEncoder(w).Encode(resp) - // } - } + switch { + case r.Method == "GET" && match(p, "/", &args): + page = pages[ "home" ] + case match(p, "/terms", &args): + page = pages[ "terms" ] + case match(p, "/app", &args): + page = pages[ "app" ] + default: + http.NotFound(w, r) + return + } + page.Render(w) } func main() { files := http.FileServer(http.Dir("")) - http.Handle("/assets/", files) http.HandleFunc("/api/", api)