146 lines
3.3 KiB
Go
146 lines
3.3 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/go-sql-driver/mysql"
|
|
"log"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
// LDB bundles the application's database handle; the query helpers in this
// file are declared as methods on it.
type LDB struct {
	// db is the database/sql handle every method runs its statements on.
	db *sql.DB
}
|
|
|
|
// NoEntriesFoundError is the sentinel error reported when a lookup produces
// no row to read.
var NoEntriesFoundError = errors.New("no entries found")
|
|
|
|
// createTLSConf assembles the TLS settings for the database connection.
//
// The CA bundle is read from the file named by DB_CA_CERT_PATH and the client
// certificate/key pair from the files named by DB_CERT_PATH and DB_KEY_PATH.
// skipVerify is copied into InsecureSkipVerify unchanged.
//
// Any failure to read or parse those files ends the process via log.Fatal, so
// the function only returns when all three files were usable.
func createTLSConf(skipVerify bool) tls.Config {
	caBytes, readErr := os.ReadFile(os.Getenv("DB_CA_CERT_PATH"))
	if readErr != nil {
		log.Fatal(readErr)
	}

	caPool := x509.NewCertPool()
	if !caPool.AppendCertsFromPEM(caBytes) {
		log.Fatal("Failed to append PEM.")
	}

	certFile, keyFile := os.Getenv("DB_CERT_PATH"), os.Getenv("DB_KEY_PATH")
	clientPair, loadErr := tls.LoadX509KeyPair(certFile, keyFile)
	if loadErr != nil {
		log.Fatal(loadErr)
	}

	// Returned as a literal on purpose: tls.Config must not be copied once it
	// has been stored in a variable and used.
	return tls.Config{
		RootCAs:            caPool,
		Certificates:       []tls.Certificate{clientPair},
		InsecureSkipVerify: skipVerify,
	}
}
|
|
|
|
func connect() *LDB {
|
|
var db *sql.DB
|
|
caCert := os.Getenv("DB_CA_CERT_PATH")
|
|
|
|
if len(caCert) > 0 {
|
|
skipVerify, err := strconv.ParseBool(os.Getenv("DB_SKIP_VERIFY"))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
tlsConf := createTLSConf(skipVerify)
|
|
err = mysql.RegisterTLSConfig("custom", &tlsConf)
|
|
if err != nil {
|
|
log.Fatalf("Error %s on RegisterTLSConfig\n", err)
|
|
return nil
|
|
}
|
|
connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?tls=custom&parseTime=true&loc=Local",
|
|
os.Getenv("DB_USER"),
|
|
os.Getenv("DB_PASS"),
|
|
os.Getenv("DB_HOST"),
|
|
os.Getenv("DB_PORT"),
|
|
os.Getenv("DB_SCHEMA"),
|
|
)
|
|
|
|
db, err = sql.Open("mysql", connStr)
|
|
if err != nil {
|
|
log.Fatalf("Error connecting to db: %s", err)
|
|
return nil
|
|
}
|
|
} else {
|
|
connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true&loc=Local",
|
|
os.Getenv("DB_USER"),
|
|
os.Getenv("DB_PASS"),
|
|
os.Getenv("DB_HOST"),
|
|
os.Getenv("DB_PORT"),
|
|
os.Getenv("DB_SCHEMA"),
|
|
)
|
|
var err error
|
|
db, err = sql.Open("mysql", connStr)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
return &LDB{db}
|
|
}
|
|
|
|
func (ldb *LDB) queryUser(userName string) (u *User, newError error) {
|
|
db := ldb.db
|
|
|
|
queryStr := `
|
|
SELECT users.user_name, users.secret_codes, users.current_token, users.contact_number, users.email_address,
|
|
addresses.street1, addresses.street2, addresses.city, addresses.state, addresses.postal_code,
|
|
addresses.country, addresses.created_at, addresses.updated_at
|
|
FROM users
|
|
INNER JOIN addresses ON users.mailing_address=addresses.id
|
|
WHERE users.user_name=?;
|
|
`
|
|
stmt, err := db.Prepare(queryStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res, err := stmt.Query(userName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer res.Close()
|
|
if !res.Next() {
|
|
return nil, NoEntriesFoundError
|
|
}
|
|
user := new(User)
|
|
err = res.Scan(&user.UserName, &user.SecretCodes, &user.CurrentToken, &user.ContactNumber, &user.EmailAddress,
|
|
&user.MailingAddress.Street1, &user.MailingAddress.Street2, &user.MailingAddress.City,
|
|
&user.MailingAddress.State, &user.MailingAddress.PostalCode, &user.MailingAddress.Country,
|
|
&user.MailingAddress.CreatedAt, &user.MailingAddress.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (ldb *LDB) updateToken(userName string, token string) error {
|
|
db := ldb.db
|
|
|
|
updateStr := `
|
|
UPDATE users
|
|
SET current_token=?,token_creation=?
|
|
WHERE user_name=?;
|
|
`
|
|
stmt, err := db.Prepare(updateStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = stmt.Exec(token, time.Now(), userName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|