274 lines
6.2 KiB
Go
274 lines
6.2 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/go-sql-driver/mysql"
|
|
"io"
|
|
"log"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
type LDB struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
var (
	// NoEntriesFoundError is returned by lookups in this file when the query
	// matches no rows.
	NoEntriesFoundError = errors.New("no entries found")

	// CodeExistsError reports that a code is already present. It is not
	// referenced in the visible part of this file.
	CodeExistsError = errors.New("code already exists")
)
|
|
|
|
// createTLSConf assembles the TLS settings for the MySQL connection.
//
// It trusts the PEM-encoded CA bundle read from DB_CA_CERT_PATH and presents
// the client certificate/key pair loaded from DB_CERT_PATH and DB_KEY_PATH.
// skipVerify is copied verbatim into InsecureSkipVerify.
//
// Any failure to read or parse these files terminates the process via
// log.Fatal; the function never returns a partially built config.
func createTLSConf(skipVerify bool) tls.Config {
	caPEM, readErr := os.ReadFile(os.Getenv("DB_CA_CERT_PATH"))
	if readErr != nil {
		log.Fatal(readErr)
	}

	trusted := x509.NewCertPool()
	if appended := trusted.AppendCertsFromPEM(caPEM); !appended {
		log.Fatal("Failed to append PEM.")
	}

	pair, pairErr := tls.LoadX509KeyPair(os.Getenv("DB_CERT_PATH"), os.Getenv("DB_KEY_PATH"))
	if pairErr != nil {
		log.Fatal(pairErr)
	}

	return tls.Config{
		RootCAs:            trusted,
		Certificates:       []tls.Certificate{pair},
		InsecureSkipVerify: skipVerify,
	}
}
|
|
|
|
func connect() *LDB {
|
|
var db *sql.DB
|
|
caCert := os.Getenv("DB_CA_CERT_PATH")
|
|
|
|
if len(caCert) > 0 {
|
|
skipVerify, err := strconv.ParseBool(os.Getenv("DB_SKIP_VERIFY"))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
tlsConf := createTLSConf(skipVerify)
|
|
err = mysql.RegisterTLSConfig("custom", &tlsConf)
|
|
if err != nil {
|
|
log.Fatalf("Error %s on RegisterTLSConfig\n", err)
|
|
return nil
|
|
}
|
|
connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?tls=custom&parseTime=true&loc=Local",
|
|
os.Getenv("DB_USER"),
|
|
os.Getenv("DB_PASS"),
|
|
os.Getenv("DB_HOST"),
|
|
os.Getenv("DB_PORT"),
|
|
os.Getenv("DB_SCHEMA"),
|
|
)
|
|
|
|
db, err = sql.Open("mysql", connStr)
|
|
if err != nil {
|
|
log.Fatalf("Error connecting to db: %s", err)
|
|
return nil
|
|
}
|
|
} else {
|
|
connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true&loc=Local",
|
|
os.Getenv("DB_USER"),
|
|
os.Getenv("DB_PASS"),
|
|
os.Getenv("DB_HOST"),
|
|
os.Getenv("DB_PORT"),
|
|
os.Getenv("DB_SCHEMA"),
|
|
)
|
|
var err error
|
|
db, err = sql.Open("mysql", connStr)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
return &LDB{db}
|
|
}
|
|
|
|
func (ldb *LDB) queryUser(userName string) (*User, error) {
|
|
db := ldb.db
|
|
|
|
queryStr := `
|
|
SELECT users.user_name, users.contact_name, users.secret_codes, users.current_token, users.contact_number, users.email_address,
|
|
addresses.street1, addresses.street2, addresses.city, addresses.state, addresses.postal_code,
|
|
addresses.country, addresses.created_at, addresses.updated_at
|
|
FROM users
|
|
INNER JOIN addresses ON users.mailing_address=addresses.id
|
|
WHERE users.user_name=?;
|
|
`
|
|
stmt, err := db.Prepare(queryStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res, err := stmt.Query(userName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer res.Close()
|
|
if !res.Next() {
|
|
return nil, NoEntriesFoundError
|
|
}
|
|
user := new(User)
|
|
err = res.Scan(&user.UserName, &user.ContactName, &user.SecretCodes, &user.CurrentToken, &user.ContactNumber, &user.EmailAddress,
|
|
&user.MailingAddress.Street1, &user.MailingAddress.Street2, &user.MailingAddress.City,
|
|
&user.MailingAddress.State, &user.MailingAddress.PostalCode, &user.MailingAddress.Country,
|
|
&user.MailingAddress.CreatedAt, &user.MailingAddress.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
// RandNum returns a string of length characters drawn uniformly at random
// from table, using crypto/rand as the entropy source.
//
// Bytes are mapped onto table by rejection sampling. Raw bytes at or above
// the largest multiple of len(table) that fits in 256 are discarded, so every
// symbol is equally likely. A plain byte%len(table) would favor the first
// 256%len(table) symbols.
//
// RandNum(0) returns "". A negative length panics, as make does. It also
// panics if the system random source fails.
func RandNum(length int) string {
	// For the 10-symbol table this is 250: bytes 250..255 are rejected.
	limit := 256 - 256%len(table)

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := io.ReadFull(rand.Reader, buf); err != nil {
			panic(err)
		}
		for _, v := range buf {
			if int(v) >= limit {
				continue
			}
			out = append(out, table[int(v)%len(table)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}

// table is the alphabet RandNum draws from: the ten decimal digits.
var table = [...]byte{'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}
|
|
|
|
func (ldb *LDB) createUser(user *User) (*string, error) {
|
|
db := ldb.db
|
|
|
|
insertStr := `
|
|
INSERT INTO addresses(street1, street2, city, state, postal_code, country, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
|
|
SELECT last_insert_id();
|
|
`
|
|
stmt, err := db.Prepare(insertStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
row := stmt.QueryRow(
|
|
user.MailingAddress.Street1, user.MailingAddress.Street2, user.MailingAddress.City, user.MailingAddress.State,
|
|
user.MailingAddress.PostalCode, user.MailingAddress.Country, time.Now(), time.Now(),
|
|
)
|
|
var id int64
|
|
err = row.Scan(&id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
insertStr = `
|
|
INSERT INTO users(user_name, contact_name, secret_codes, contact_number, email_address, mailing_address)
|
|
VALUES (?, ?, ?, ?, ?, ?);
|
|
`
|
|
stmt, err = db.Prepare(insertStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
code := RandNum(6)
|
|
_, err = stmt.Exec(user.UserName, user.ContactName, code, user.ContactName, user.EmailAddress, user.MailingAddress)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &code, nil
|
|
}
|
|
|
|
func (ldb *LDB) getUsers() ([]string, error) {
|
|
db := ldb.db
|
|
|
|
queryStr := `
|
|
SELECT user_name FROM users;
|
|
`
|
|
rows, err := db.Query(queryStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var users []string
|
|
var user string
|
|
for rows.Next() {
|
|
err = rows.Scan(&user)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
users = append(users, user)
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (ldb *LDB) updateToken(userName string, token string) error {
|
|
db := ldb.db
|
|
|
|
updateStr := `
|
|
UPDATE users
|
|
SET current_token=?,token_creation=?
|
|
WHERE user_name=?;
|
|
`
|
|
stmt, err := db.Prepare(updateStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = stmt.Exec(token, time.Now(), userName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Creates an 8-digit token with GenerateToken, then inserts it into the db
|
|
// If the generated token already exists it will be regenerated up to 10 times
|
|
// Returns the token and any errors
|
|
func (ldb *LDB) createRegistrationCode() (*string, error) {
|
|
db := ldb.db
|
|
code, err := GenerateToken(8)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
_, err = db.Exec(
|
|
"INSERT INTO registration_codes (registration_code, expiration) VALUES (?, ?)",
|
|
code, time.Now(),
|
|
)
|
|
if err == nil {
|
|
return &code, nil
|
|
}
|
|
|
|
var mySqlErr *mysql.MySQLError
|
|
if errors.As(err, &mySqlErr) && mySqlErr.Number == 1062 {
|
|
continue
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return nil, errors.New("could not generate a unique token after 10 retries")
|
|
}
|
|
|
|
// Queries DB for any matching code
|
|
// Returns no error if code is found
|
|
// Otherwise returns NoEntryFoundError
|
|
func (ldb *LDB) checkRegistrationCode(code string) error {
|
|
db := ldb.db
|
|
|
|
querySql := `
|
|
SELECT * FROM registration_codes
|
|
WHERE registration_code=?;
|
|
`
|
|
rows, err := db.Query(querySql, code)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rows.Next() {
|
|
return nil
|
|
}
|
|
return NoEntriesFoundError
|
|
}
|