Refactor global variables into injected dependencies
This commit is contained in:
parent
e5cd6e0be3
commit
4a79379d8e
10 changed files with 366 additions and 182 deletions
|
@ -8,7 +8,8 @@ import (
|
||||||
"git.klink.asia/paul/certman/models"
|
"git.klink.asia/paul/certman/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterHandler(w http.ResponseWriter, req *http.Request) {
|
func RegisterHandler(p *services.Provider) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
// Get parameters
|
// Get parameters
|
||||||
email := req.Form.Get("email")
|
email := req.Form.Get("email")
|
||||||
password := req.Form.Get("password")
|
password := req.Form.Get("password")
|
||||||
|
@ -17,12 +18,12 @@ func RegisterHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
user.Email = email
|
user.Email = email
|
||||||
user.SetPassword(password)
|
user.SetPassword(password)
|
||||||
|
|
||||||
err := services.Database.Create(&user).Error
|
err := p.DB.Create(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err.Error)
|
panic(err.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
services.SessionStore.Flash(w, req,
|
p.Sessions.Flash(w, req,
|
||||||
services.Flash{
|
services.Flash{
|
||||||
Type: "success",
|
Type: "success",
|
||||||
Message: "The user was created. Check your inbox for the confirmation email.",
|
Message: "The user was created. Check your inbox for the confirmation email.",
|
||||||
|
@ -31,19 +32,21 @@ func RegisterHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
http.Redirect(w, req, "/login", http.StatusFound)
|
http.Redirect(w, req, "/login", http.StatusFound)
|
||||||
return
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoginHandler(w http.ResponseWriter, req *http.Request) {
|
func LoginHandler(p *services.Provider) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
// Get parameters
|
// Get parameters
|
||||||
email := req.Form.Get("email")
|
email := req.Form.Get("email")
|
||||||
password := req.Form.Get("password")
|
password := req.Form.Get("password")
|
||||||
|
|
||||||
user := models.User{}
|
user := models.User{}
|
||||||
|
|
||||||
err := services.Database.Where(&models.User{Email: email}).Find(&user).Error
|
err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// could not find user
|
// could not find user
|
||||||
services.SessionStore.Flash(
|
p.Sessions.Flash(
|
||||||
w, req, services.Flash{
|
w, req, services.Flash{
|
||||||
Type: "warning", Message: "Invalid Email or Password.",
|
Type: "warning", Message: "Invalid Email or Password.",
|
||||||
},
|
},
|
||||||
|
@ -53,7 +56,7 @@ func LoginHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.EmailValid {
|
if !user.EmailValid {
|
||||||
services.SessionStore.Flash(
|
p.Sessions.Flash(
|
||||||
w, req, services.Flash{
|
w, req, services.Flash{
|
||||||
Type: "warning", Message: "You need to confirm your email before logging in.",
|
Type: "warning", Message: "You need to confirm your email before logging in.",
|
||||||
},
|
},
|
||||||
|
@ -64,7 +67,7 @@ func LoginHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
if err := user.CheckPassword(password); err != nil {
|
if err := user.CheckPassword(password); err != nil {
|
||||||
// wrong password
|
// wrong password
|
||||||
services.SessionStore.Flash(
|
p.Sessions.Flash(
|
||||||
w, req, services.Flash{
|
w, req, services.Flash{
|
||||||
Type: "warning", Message: "Invalid Email or Password.",
|
Type: "warning", Message: "Invalid Email or Password.",
|
||||||
},
|
},
|
||||||
|
@ -74,7 +77,8 @@ func LoginHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// user is logged in, set cookie
|
// user is logged in, set cookie
|
||||||
services.SessionStore.SetUserEmail(w, req, email)
|
p.Sessions.SetUserEmail(w, req, email)
|
||||||
|
|
||||||
http.Redirect(w, req, "/certs", http.StatusSeeOther)
|
http.Redirect(w, req, "/certs", http.StatusSeeOther)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,17 +19,20 @@ import (
|
||||||
"git.klink.asia/paul/certman/views"
|
"git.klink.asia/paul/certman/views"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ListCertHandler(w http.ResponseWriter, req *http.Request) {
|
func ListCertHandler(p *services.Provider) http.HandlerFunc {
|
||||||
v := views.New(req)
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
v := views.NewWithSession(req, p.Sessions)
|
||||||
v.Render(w, "cert_list")
|
v.Render(w, "cert_list")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateCertHandler(w http.ResponseWriter, req *http.Request) {
|
func CreateCertHandler(p *services.Provider) http.HandlerFunc {
|
||||||
email := services.SessionStore.GetUserEmail(req)
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
email := p.Sessions.GetUserEmail(req)
|
||||||
certname := req.FormValue("certname")
|
certname := req.FormValue("certname")
|
||||||
|
|
||||||
user := models.User{}
|
user := models.User{}
|
||||||
err := services.Database.Where(&models.User{Email: email}).Find(&user).Error
|
err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Could not fetch user for mail %s\n", email)
|
fmt.Printf("Could not fetch user for mail %s\n", email)
|
||||||
}
|
}
|
||||||
|
@ -59,11 +62,11 @@ func CreateCertHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert client into database
|
// Insert client into database
|
||||||
if err := services.Database.Create(&client).Error; err != nil {
|
if err := p.DB.Create(&client).Error; err != nil {
|
||||||
panic(err.Error())
|
panic(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
services.SessionStore.Flash(w, req,
|
p.Sessions.Flash(w, req,
|
||||||
services.Flash{
|
services.Flash{
|
||||||
Type: "success",
|
Type: "success",
|
||||||
Message: "The certificate was created successfully.",
|
Message: "The certificate was created successfully.",
|
||||||
|
@ -71,9 +74,11 @@ func CreateCertHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
)
|
)
|
||||||
|
|
||||||
http.Redirect(w, req, "/certs", http.StatusFound)
|
http.Redirect(w, req, "/certs", http.StatusFound)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownloadCertHandler(w http.ResponseWriter, req *http.Request) {
|
func DownloadCertHandler(p *services.Provider) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
//v := views.New(req)
|
//v := views.New(req)
|
||||||
//
|
//
|
||||||
//derBytes, err := CreateCertificate(key, caCert, caKey)
|
//derBytes, err := CreateCertificate(key, caCert, caKey)
|
||||||
|
@ -82,6 +87,7 @@ func DownloadCertHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
//pkBytes := x509.MarshalPKCS1PrivateKey(key)
|
//pkBytes := x509.MarshalPKCS1PrivateKey(key)
|
||||||
//pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes})
|
//pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes})
|
||||||
return
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadX509KeyPair(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) {
|
func loadX509KeyPair(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) {
|
||||||
|
|
34
main.go
34
main.go
|
@ -3,22 +3,46 @@ package main
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/securecookie"
|
||||||
|
|
||||||
"git.klink.asia/paul/certman/services"
|
"git.klink.asia/paul/certman/services"
|
||||||
|
|
||||||
"git.klink.asia/paul/certman/router"
|
"git.klink.asia/paul/certman/router"
|
||||||
"git.klink.asia/paul/certman/views"
|
"git.klink.asia/paul/certman/views"
|
||||||
|
|
||||||
// import sqlite3 driver once
|
// import sqlite3 driver
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
c := services.Config{
|
||||||
|
DB: &services.DBConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
DSN: "db.sqlite3",
|
||||||
|
Log: true,
|
||||||
|
},
|
||||||
|
Sessions: &services.SessionsConfig{
|
||||||
|
SessionName: "_session",
|
||||||
|
CookieKey: string(securecookie.GenerateRandomKey(32)),
|
||||||
|
HttpOnly: true,
|
||||||
|
Lifetime: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
Email: &services.EmailConfig{
|
||||||
|
SMTPServer: "example.com",
|
||||||
|
SMTPPort: 25,
|
||||||
|
SMTPUsername: "test",
|
||||||
|
SMTPPassword: "test",
|
||||||
|
From: "Mailtest <test@example.com>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Connect to the database
|
serviceProvider := services.NewProvider(&c)
|
||||||
db := services.InitDB()
|
|
||||||
|
|
||||||
services.InitSession()
|
// Start the mail daemon, which re-uses connections to send mails to the
|
||||||
|
// SMTP server
|
||||||
|
go serviceProvider.Email.Daemon()
|
||||||
|
|
||||||
//user := models.User{}
|
//user := models.User{}
|
||||||
//user.Username = "test"
|
//user.Username = "test"
|
||||||
|
@ -29,7 +53,7 @@ func main() {
|
||||||
// load and parse template files
|
// load and parse template files
|
||||||
views.LoadTemplates()
|
views.LoadTemplates()
|
||||||
|
|
||||||
mux := router.HandleRoutes(db)
|
mux := router.HandleRoutes(serviceProvider)
|
||||||
|
|
||||||
err := http.ListenAndServe(":8000", mux)
|
err := http.ListenAndServe(":8000", mux)
|
||||||
log.Fatalf(err.Error())
|
log.Fatalf(err.Error())
|
||||||
|
|
|
@ -8,13 +8,15 @@ import (
|
||||||
|
|
||||||
// RequireLogin is a middleware that checks for a username in the active
|
// RequireLogin is a middleware that checks for a username in the active
|
||||||
// session, and redirects to `/login` if no username was found.
|
// session, and redirects to `/login` if no username was found.
|
||||||
func RequireLogin(next http.Handler) http.Handler {
|
func RequireLogin(sessions *services.Sessions) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
fn := func(w http.ResponseWriter, req *http.Request) {
|
fn := func(w http.ResponseWriter, req *http.Request) {
|
||||||
if username := services.SessionStore.GetUserEmail(req); username == "" {
|
if username := sessions.GetUserEmail(req); username == "" {
|
||||||
http.Redirect(w, req, "/login", http.StatusFound)
|
http.Redirect(w, req, "/login", http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
next.ServeHTTP(w, req)
|
next.ServeHTTP(w, req)
|
||||||
}
|
}
|
||||||
return http.HandlerFunc(fn)
|
return http.HandlerFunc(fn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/go-chi/chi/middleware"
|
"github.com/go-chi/chi/middleware"
|
||||||
"github.com/gorilla/csrf"
|
"github.com/gorilla/csrf"
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
|
|
||||||
mw "git.klink.asia/paul/certman/middleware"
|
mw "git.klink.asia/paul/certman/middleware"
|
||||||
)
|
)
|
||||||
|
@ -26,7 +25,7 @@ var (
|
||||||
cookieKey = []byte("osx70sMD8HZG2ouUl8uKI4wcMugiJ2WH")
|
cookieKey = []byte("osx70sMD8HZG2ouUl8uKI4wcMugiJ2WH")
|
||||||
)
|
)
|
||||||
|
|
||||||
func HandleRoutes(db *gorm.DB) http.Handler {
|
func HandleRoutes(provider *services.Provider) http.Handler {
|
||||||
mux := chi.NewMux()
|
mux := chi.NewMux()
|
||||||
|
|
||||||
//mux.Use(middleware.RequestID)
|
//mux.Use(middleware.RequestID)
|
||||||
|
@ -34,7 +33,7 @@ func HandleRoutes(db *gorm.DB) http.Handler {
|
||||||
mux.Use(middleware.RealIP) // use proxy headers
|
mux.Use(middleware.RealIP) // use proxy headers
|
||||||
mux.Use(middleware.RedirectSlashes) // redirect trailing slashes
|
mux.Use(middleware.RedirectSlashes) // redirect trailing slashes
|
||||||
mux.Use(mw.Recoverer) // recover on panic
|
mux.Use(mw.Recoverer) // recover on panic
|
||||||
mux.Use(services.SessionStore.Use) // use session storage
|
mux.Use(provider.Sessions.Manager.Use) // use session storage
|
||||||
|
|
||||||
// we are serving the static files directly from the assets package
|
// we are serving the static files directly from the assets package
|
||||||
// this either means we use the embedded files, or live-load
|
// this either means we use the embedded files, or live-load
|
||||||
|
@ -56,26 +55,26 @@ func HandleRoutes(db *gorm.DB) http.Handler {
|
||||||
|
|
||||||
r.Route("/register", func(r chi.Router) {
|
r.Route("/register", func(r chi.Router) {
|
||||||
r.Get("/", v("register"))
|
r.Get("/", v("register"))
|
||||||
r.Post("/", handlers.RegisterHandler)
|
r.Post("/", handlers.RegisterHandler(provider))
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Route("/login", func(r chi.Router) {
|
r.Route("/login", func(r chi.Router) {
|
||||||
r.Get("/", v("login"))
|
r.Get("/", v("login"))
|
||||||
r.Post("/", handlers.LoginHandler)
|
r.Post("/", handlers.LoginHandler(provider))
|
||||||
})
|
})
|
||||||
|
|
||||||
//r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(db))
|
//r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(db))
|
||||||
|
|
||||||
r.Route("/forgot-password", func(r chi.Router) {
|
r.Route("/forgot-password", func(r chi.Router) {
|
||||||
r.Get("/", v("forgot-password"))
|
r.Get("/", v("forgot-password"))
|
||||||
r.Post("/", handlers.LoginHandler)
|
r.Post("/", handlers.LoginHandler(provider))
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Route("/certs", func(r chi.Router) {
|
r.Route("/certs", func(r chi.Router) {
|
||||||
r.Use(mw.RequireLogin)
|
r.Use(mw.RequireLogin(provider.Sessions))
|
||||||
r.Get("/", handlers.ListCertHandler)
|
r.Get("/", handlers.ListCertHandler(provider))
|
||||||
r.Post("/new", handlers.CreateCertHandler)
|
r.Post("/new", handlers.CreateCertHandler(provider))
|
||||||
r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler)
|
r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler(provider))
|
||||||
})
|
})
|
||||||
|
|
||||||
r.HandleFunc("/500", func(w http.ResponseWriter, req *http.Request) {
|
r.HandleFunc("/500", func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"git.klink.asia/paul/certman/models"
|
"git.klink.asia/paul/certman/models"
|
||||||
"git.klink.asia/paul/certman/settings"
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,28 +13,34 @@ var (
|
||||||
ErrNotImplemented = errors.New("Not implemented")
|
ErrNotImplemented = errors.New("Not implemented")
|
||||||
)
|
)
|
||||||
|
|
||||||
var Database *gorm.DB
|
type DBConfig struct {
|
||||||
|
Type string
|
||||||
|
DSN string
|
||||||
|
Log bool
|
||||||
|
}
|
||||||
|
|
||||||
// DB is a wrapper around gorm.DB to provide custom methods
|
// DB is a wrapper around gorm.DB to provide custom methods
|
||||||
type DB struct {
|
type DB struct {
|
||||||
*gorm.DB
|
*gorm.DB
|
||||||
|
|
||||||
|
conf *DBConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitDB() *gorm.DB {
|
func NewDB(conf *DBConfig) *DB {
|
||||||
dsn := settings.Get("DATABASE_URL", "db.sqlite3")
|
|
||||||
|
|
||||||
// Establish connection
|
// Establish connection
|
||||||
db, err := gorm.Open("sqlite3", dsn)
|
db, err := gorm.Open(conf.Type, conf.DSN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Could not open database: %s", err.Error())
|
log.Fatalf("Could not open database: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate models
|
// Migrate models
|
||||||
db.AutoMigrate(models.User{}, models.Client{})
|
db.AutoMigrate(models.User{}, models.Client{})
|
||||||
db.LogMode(true)
|
db.LogMode(conf.Log)
|
||||||
|
|
||||||
Database = db
|
return &DB{
|
||||||
return db
|
DB: db,
|
||||||
|
conf: conf,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountUsers returns the number of Users in the datastore
|
// CountUsers returns the number of Users in the datastore
|
||||||
|
|
103
services/email.go
Normal file
103
services/email.go
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-mail/mail"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMailUninitializedConfig = errors.New("Mail: uninitialized config")
|
||||||
|
)
|
||||||
|
|
||||||
|
type EmailConfig struct {
|
||||||
|
From string
|
||||||
|
SMTPServer string
|
||||||
|
SMTPPort int
|
||||||
|
SMTPUsername string
|
||||||
|
SMTPPassword string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Email struct {
|
||||||
|
config *EmailConfig
|
||||||
|
|
||||||
|
mailChan chan *mail.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEmail(conf *EmailConfig) *Email {
|
||||||
|
if conf == nil {
|
||||||
|
log.Println(ErrMailUninitializedConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Email{
|
||||||
|
config: conf,
|
||||||
|
mailChan: make(chan *mail.Message, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends an email to the receiver
|
||||||
|
func (email *Email) Send(to, subject, text, html string) error {
|
||||||
|
if email.config == nil {
|
||||||
|
log.Print("Error: trying to send mail with uninitialized config.")
|
||||||
|
return ErrMailUninitializedConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
m := mail.NewMessage()
|
||||||
|
m.SetHeader("From", email.config.From)
|
||||||
|
m.SetHeader("To", to)
|
||||||
|
m.SetHeader("Subject", subject)
|
||||||
|
m.SetBody("text/plain", text)
|
||||||
|
m.AddAlternative("text/html", html)
|
||||||
|
|
||||||
|
// put email in chan
|
||||||
|
email.mailChan <- m
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Daemon is a function that takes Mail and sends it without blocking.
|
||||||
|
// WIP
|
||||||
|
func (email *Email) Daemon() {
|
||||||
|
if email.config == nil {
|
||||||
|
log.Print("Error: trying to set up mail deamon with uninitialized config.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d := mail.NewDialer(
|
||||||
|
email.config.SMTPServer,
|
||||||
|
email.config.SMTPPort,
|
||||||
|
email.config.SMTPUsername,
|
||||||
|
email.config.SMTPPassword)
|
||||||
|
|
||||||
|
var s mail.SendCloser
|
||||||
|
var err error
|
||||||
|
open := false
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case m, ok := <-email.mailChan:
|
||||||
|
if !ok {
|
||||||
|
// channel is closed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !open {
|
||||||
|
if s, err = d.Dial(); err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
open = true
|
||||||
|
}
|
||||||
|
if err := mail.Send(s, m); err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
}
|
||||||
|
// Close the connection if no email was sent in the last 30 seconds.
|
||||||
|
case <-time.After(30 * time.Second):
|
||||||
|
if open {
|
||||||
|
if err := s.Close(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
open = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
24
services/provider.go
Normal file
24
services/provider.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package services
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
DB *DBConfig
|
||||||
|
Sessions *SessionsConfig
|
||||||
|
Email *EmailConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type Provider struct {
|
||||||
|
DB *DB
|
||||||
|
Sessions *Sessions
|
||||||
|
Email *Email
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProvider returns the ServiceProvider
|
||||||
|
func NewProvider(conf *Config) *Provider {
|
||||||
|
var provider = &Provider{}
|
||||||
|
|
||||||
|
provider.DB = NewDB(conf.DB)
|
||||||
|
provider.Sessions = NewSessions(conf.Sessions)
|
||||||
|
provider.Email = NewEmail(conf.Email)
|
||||||
|
|
||||||
|
return provider
|
||||||
|
}
|
|
@ -8,16 +8,10 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.klink.asia/paul/certman/settings"
|
|
||||||
"github.com/alexedwards/scs"
|
"github.com/alexedwards/scs"
|
||||||
"github.com/gorilla/securecookie"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// SessionName is the name of the session cookie
|
|
||||||
SessionName = "session"
|
|
||||||
// CookieKey is the key the cookies are encrypted and signed with
|
|
||||||
CookieKey = string(securecookie.GenerateRandomKey(32))
|
|
||||||
// FlashesKey is the key used for the flashes in the cookie
|
// FlashesKey is the key used for the flashes in the cookie
|
||||||
FlashesKey = "_flashes"
|
FlashesKey = "_flashes"
|
||||||
// UserEmailKey is the key used to reference usernames
|
// UserEmailKey is the key used to reference usernames
|
||||||
|
@ -29,30 +23,33 @@ func init() {
|
||||||
gob.Register(Flash{})
|
gob.Register(Flash{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionStore is a globally accessible sessions store for the application
|
type SessionsConfig struct {
|
||||||
var SessionStore *Store
|
SessionName string
|
||||||
|
CookieKey string
|
||||||
|
HttpOnly bool
|
||||||
|
Secure bool
|
||||||
|
Lifetime time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
// Store is a wrapped scs.Store in order to implement custom
|
// Sessions is a wrapped scs.Store in order to implement custom logic
|
||||||
// logic
|
type Sessions struct {
|
||||||
type Store struct {
|
|
||||||
*scs.Manager
|
*scs.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitSession populates the default sessions Store
|
// NewSessions populates the default sessions Store
|
||||||
func InitSession() {
|
func NewSessions(conf *SessionsConfig) *Sessions {
|
||||||
store := scs.NewCookieManager(
|
store := scs.NewCookieManager(
|
||||||
CookieKey,
|
conf.CookieKey,
|
||||||
)
|
)
|
||||||
|
store.Name(conf.SessionName)
|
||||||
store.HttpOnly(true)
|
store.HttpOnly(true)
|
||||||
store.Lifetime(24 * time.Hour)
|
store.Lifetime(conf.Lifetime)
|
||||||
|
store.Secure(conf.Secure)
|
||||||
|
|
||||||
// Use secure cookies (HTTPS only) in production
|
return &Sessions{store}
|
||||||
store.Secure(settings.Get("ENVIRONMENT", "") == "production")
|
|
||||||
|
|
||||||
SessionStore = &Store{store}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *Store) GetUserEmail(req *http.Request) string {
|
func (store *Sessions) GetUserEmail(req *http.Request) string {
|
||||||
if store == nil {
|
if store == nil {
|
||||||
// if store was not initialized, all requests fail
|
// if store was not initialized, all requests fail
|
||||||
log.Println("Zero pointer when checking session for username")
|
log.Println("Zero pointer when checking session for username")
|
||||||
|
@ -72,7 +69,7 @@ func (store *Store) GetUserEmail(req *http.Request) string {
|
||||||
return email
|
return email
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *Store) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) {
|
func (store *Sessions) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) {
|
||||||
if store == nil {
|
if store == nil {
|
||||||
// if store was not initialized, do nothing
|
// if store was not initialized, do nothing
|
||||||
return
|
return
|
||||||
|
@ -103,7 +100,7 @@ func (flash Flash) Render() template.HTML {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flash add flash message to session data
|
// Flash add flash message to session data
|
||||||
func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error {
|
func (store *Sessions) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error {
|
||||||
var flashes []Flash
|
var flashes []Flash
|
||||||
|
|
||||||
sess := store.Load(req)
|
sess := store.Load(req)
|
||||||
|
@ -118,7 +115,7 @@ func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flashes returns a slice of flash messages from session data
|
// Flashes returns a slice of flash messages from session data
|
||||||
func (store *Store) Flashes(w http.ResponseWriter, req *http.Request) []Flash {
|
func (store *Sessions) Flashes(w http.ResponseWriter, req *http.Request) []Flash {
|
||||||
var flashes []Flash
|
var flashes []Flash
|
||||||
sess := store.Load(req)
|
sess := store.Load(req)
|
||||||
sess.PopObject(w, FlashesKey, &flashes)
|
sess.PopObject(w, FlashesKey, &flashes)
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
type View struct {
|
type View struct {
|
||||||
Vars map[string]interface{}
|
Vars map[string]interface{}
|
||||||
Request *http.Request
|
Request *http.Request
|
||||||
|
SessionStore *services.Sessions
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(req *http.Request) *View {
|
func New(req *http.Request) *View {
|
||||||
|
@ -23,12 +24,29 @@ func New(req *http.Request) *View {
|
||||||
Vars: map[string]interface{}{
|
Vars: map[string]interface{}{
|
||||||
"CSRF_TOKEN": csrf.Token(req),
|
"CSRF_TOKEN": csrf.Token(req),
|
||||||
"csrfField": csrf.TemplateField(req),
|
"csrfField": csrf.TemplateField(req),
|
||||||
"username": services.SessionStore.GetUserEmail(req),
|
|
||||||
"Meta": map[string]interface{}{
|
"Meta": map[string]interface{}{
|
||||||
"Path": req.URL.Path,
|
"Path": req.URL.Path,
|
||||||
"Env": "develop",
|
"Env": "develop",
|
||||||
},
|
},
|
||||||
"flashes": []services.Flash{},
|
"flashes": []services.Flash{},
|
||||||
|
"username": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWithSession(req *http.Request, sessionStore *services.Sessions) *View {
|
||||||
|
return &View{
|
||||||
|
Request: req,
|
||||||
|
SessionStore: sessionStore,
|
||||||
|
Vars: map[string]interface{}{
|
||||||
|
"CSRF_TOKEN": csrf.Token(req),
|
||||||
|
"csrfField": csrf.TemplateField(req),
|
||||||
|
"Meta": map[string]interface{}{
|
||||||
|
"Path": req.URL.Path,
|
||||||
|
"Env": "develop",
|
||||||
|
},
|
||||||
|
"flashes": []services.Flash{},
|
||||||
|
"username": sessionStore.GetUserEmail(req),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -43,8 +61,10 @@ func (view View) Render(w http.ResponseWriter, name string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if view.SessionStore != nil {
|
||||||
// add flashes to template
|
// add flashes to template
|
||||||
view.Vars["flashes"] = services.SessionStore.Flashes(w, view.Request)
|
view.Vars["flashes"] = view.SessionStore.Flashes(w, view.Request)
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
Loading…
Reference in a new issue