LuggageTracker/db.go
2025-07-14 22:08:10 -04:00

274 lines
6.2 KiB
Go

package main
import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"github.com/go-sql-driver/mysql"
"io"
"log"
"os"
"strconv"
"time"
)
// LDB wraps a database/sql connection pool; the data-access methods in this
// file are defined on *LDB.
type LDB struct {
	db *sql.DB
}

// Sentinel errors for this package; compare against them with errors.Is.
var (
	// NoEntriesFoundError is returned when a lookup (queryUser,
	// checkRegistrationCode) matches no rows.
	NoEntriesFoundError = errors.New("no entries found")
	// CodeExistsError signals that a code is already taken.
	// NOTE(review): not referenced by any code visible in this file.
	CodeExistsError = errors.New("code already exists")
)
// createTLSConf builds the TLS configuration for the MySQL connection.
//
// The trusted CA bundle is read from the file named by DB_CA_CERT_PATH, and
// the client certificate/key pair from DB_CERT_PATH and DB_KEY_PATH. Any
// failure to read or parse these files terminates the process via log.Fatal.
// skipVerify is copied verbatim into InsecureSkipVerify.
func createTLSConf(skipVerify bool) tls.Config {
	caPEM, err := os.ReadFile(os.Getenv("DB_CA_CERT_PATH"))
	if err != nil {
		log.Fatal(err)
	}
	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(caPEM) {
		log.Fatal("Failed to append PEM.")
	}

	pair, err := tls.LoadX509KeyPair(os.Getenv("DB_CERT_PATH"), os.Getenv("DB_KEY_PATH"))
	if err != nil {
		log.Fatal(err)
	}

	return tls.Config{
		RootCAs:            roots,
		Certificates:       []tls.Certificate{pair},
		InsecureSkipVerify: skipVerify,
	}
}
// connect opens the MySQL connection pool described by the environment.
//
// DB_USER, DB_PASS, DB_HOST, DB_PORT and DB_SCHEMA are formatted into the DSN
// "user:pass@tcp(host:port)/schema", and every connection uses
// parseTime=true&loc=Local.
//
// When DB_CA_CERT_PATH is non-empty the connection uses TLS: createTLSConf
// builds the config, it is registered with the driver under the name
// "custom", and tls=custom is added to the DSN. DB_SKIP_VERIFY (parsed with
// strconv.ParseBool) controls InsecureSkipVerify and defaults to false when
// unset.
//
// Configuration errors are fatal (log.Fatal*), so the function never returns
// nil. sql.Open does not necessarily contact the server, so an unreachable
// database may only surface on the first query.
func connect() *LDB {
	params := "parseTime=true&loc=Local"

	if caCert := os.Getenv("DB_CA_CERT_PATH"); len(caCert) > 0 {
		// An unset DB_SKIP_VERIFY means "verify" rather than a fatal
		// ParseBool("") syntax error.
		skipVerify := false
		if raw := os.Getenv("DB_SKIP_VERIFY"); raw != "" {
			v, err := strconv.ParseBool(raw)
			if err != nil {
				log.Fatalf("invalid DB_SKIP_VERIFY %q: %s", raw, err)
			}
			skipVerify = v
		}

		tlsConf := createTLSConf(skipVerify)
		if err := mysql.RegisterTLSConfig("custom", &tlsConf); err != nil {
			log.Fatalf("Error %s on RegisterTLSConfig\n", err)
		}
		params = "tls=custom&" + params
	}

	connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s",
		os.Getenv("DB_USER"),
		os.Getenv("DB_PASS"),
		os.Getenv("DB_HOST"),
		os.Getenv("DB_PORT"),
		os.Getenv("DB_SCHEMA"),
		params,
	)
	db, err := sql.Open("mysql", connStr)
	if err != nil {
		log.Fatalf("Error connecting to db: %s", err)
	}
	return &LDB{db}
}
// queryUser loads the user named userName together with the mailing address
// joined from the addresses table via users.mailing_address.
//
// It returns NoEntriesFoundError when no user has that name. Any other query
// or scan error is returned unchanged. If several rows match, the first one
// is used.
//
// QueryRow is used instead of an explicit Prepare. The old code never closed
// its *sql.Stmt, which leaked a server-side prepared statement on every call.
func (ldb *LDB) queryUser(userName string) (*User, error) {
	const queryStr = `
SELECT users.user_name, users.contact_name, users.secret_codes, users.current_token, users.contact_number, users.email_address,
addresses.street1, addresses.street2, addresses.city, addresses.state, addresses.postal_code,
addresses.country, addresses.created_at, addresses.updated_at
FROM users
INNER JOIN addresses ON users.mailing_address=addresses.id
WHERE users.user_name=?;
`
	user := new(User)
	err := ldb.db.QueryRow(queryStr, userName).Scan(
		&user.UserName, &user.ContactName, &user.SecretCodes, &user.CurrentToken, &user.ContactNumber, &user.EmailAddress,
		&user.MailingAddress.Street1, &user.MailingAddress.Street2, &user.MailingAddress.City,
		&user.MailingAddress.State, &user.MailingAddress.PostalCode, &user.MailingAddress.Country,
		&user.MailingAddress.CreatedAt, &user.MailingAddress.UpdatedAt,
	)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, NoEntriesFoundError
	}
	if err != nil {
		return nil, err
	}
	return user, nil
}
// RandNum returns a string of length random decimal digits drawn from
// crypto/rand. It panics if the system randomness source fails. A
// non-positive length yields "".
//
// Each digit is uniformly distributed. 256 is not a multiple of len(table),
// so reducing every byte modulo 10 would make some digits slightly more
// likely. Bytes at or above the largest multiple (250) are discarded instead.
func RandNum(length int) string {
	if length <= 0 {
		return ""
	}
	// Largest multiple of len(table) not exceeding 256 (250 for ten digits).
	limit := 256 - 256%len(table)

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := io.ReadFull(rand.Reader, buf); err != nil {
			panic(err)
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue // rejection sampling: avoids modulo bias
			}
			out = append(out, table[int(b)%len(table)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}

// table maps an index in [0, 10) to the digit character emitted by RandNum.
var table = [...]byte{'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}
// createUser inserts the user's mailing address, then a users row that
// references it. Both inserts run in a single transaction, so a failure in
// the second cannot leave an orphaned address behind.
//
// A fresh secret code from RandNum(6) is stored in users.secret_codes and
// returned to the caller. Both rows are rolled back on any error.
func (ldb *LDB) createUser(user *User) (*string, error) {
	const insertAddress = `
INSERT INTO addresses(street1, street2, city, state, postal_code, country, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
	const insertUser = `
INSERT INTO users(user_name, contact_name, secret_codes, contact_number, email_address, mailing_address)
VALUES (?, ?, ?, ?, ?, ?)
`
	tx, err := ldb.db.Begin()
	if err != nil {
		return nil, err
	}
	// After a successful Commit, Rollback is a harmless no-op (sql.ErrTxDone).
	defer func() { _ = tx.Rollback() }()

	addr := user.MailingAddress
	now := time.Now()
	res, err := tx.Exec(insertAddress,
		addr.Street1, addr.Street2, addr.City, addr.State,
		addr.PostalCode, addr.Country, now, now,
	)
	if err != nil {
		return nil, err
	}
	// Use the driver-reported id rather than appending "SELECT last_insert_id()"
	// to the INSERT: a MySQL prepared statement may contain only one statement.
	addressID, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}

	code := RandNum(6)
	// contact_number takes ContactNumber, and mailing_address takes the new
	// address row's id, not the address struct itself.
	if _, err := tx.Exec(insertUser,
		user.UserName, user.ContactName, code, user.ContactNumber, user.EmailAddress, addressID,
	); err != nil {
		return nil, err
	}

	if err := tx.Commit(); err != nil {
		return nil, err
	}
	return &code, nil
}
// getUsers returns every users.user_name in the order the database yields
// them. It returns a nil slice when the table is empty.
//
// The result set is always closed, and errors raised during iteration are
// reported via rows.Err(). The old code did neither, leaking a pooled
// connection on every call.
func (ldb *LDB) getUsers() ([]string, error) {
	const queryStr = `
SELECT user_name FROM users;
`
	rows, err := ldb.db.Query(queryStr)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []string
	for rows.Next() {
		var user string
		if err := rows.Scan(&user); err != nil {
			return nil, err
		}
		users = append(users, user)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return users, nil
}
// updateToken sets current_token to token and token_creation to the current
// time for the user named userName.
//
// Matching no rows is not treated as an error. db.Exec is used directly: the
// previous explicit Prepare never closed its *sql.Stmt, which leaked a
// server-side prepared statement on every call.
func (ldb *LDB) updateToken(userName string, token string) error {
	const updateStr = `
UPDATE users
SET current_token=?,token_creation=?
WHERE user_name=?;
`
	_, err := ldb.db.Exec(updateStr, token, time.Now(), userName)
	return err
}
// createRegistrationCode generates a token with GenerateToken(8) and inserts
// it into registration_codes. If the insert fails with MySQL error 1062
// (duplicate entry), a NEW token is generated and the insert retried, for at
// most 10 attempts in total. It returns the stored token, or the first
// non-duplicate error encountered.
//
// NOTE(review): expiration is stored as time.Now(), so the code is already
// at its expiration time when created. Confirm whether now+TTL was intended.
func (ldb *LDB) createRegistrationCode() (*string, error) {
	const maxAttempts = 10
	const mysqlDuplicateEntry = 1062 // ER_DUP_ENTRY

	for attempt := 0; attempt < maxAttempts; attempt++ {
		// Regenerate on every attempt. Retrying the same colliding code
		// can never succeed.
		code, err := GenerateToken(8)
		if err != nil {
			return nil, err
		}
		_, err = ldb.db.Exec(
			"INSERT INTO registration_codes (registration_code, expiration) VALUES (?, ?)",
			code, time.Now(),
		)
		if err == nil {
			return &code, nil
		}
		var mySqlErr *mysql.MySQLError
		if errors.As(err, &mySqlErr) && mySqlErr.Number == mysqlDuplicateEntry {
			continue
		}
		return nil, err
	}
	return nil, errors.New("could not generate a unique token after 10 retries")
}
// checkRegistrationCode reports whether code exists in registration_codes.
//
// It returns nil if at least one row matches and NoEntriesFoundError if none
// does. Any query error is returned as-is. Uses QueryRow, which always
// releases its connection. The old Query/rows.Next version never closed the
// result set and leaked a pooled connection on every call.
//
// NOTE(review): the expiration column is not consulted here. Confirm whether
// expired codes should be rejected.
func (ldb *LDB) checkRegistrationCode(code string) error {
	const querySQL = `
SELECT 1 FROM registration_codes
WHERE registration_code=?
LIMIT 1
`
	var found int
	err := ldb.db.QueryRow(querySQL, code).Scan(&found)
	if errors.Is(err, sql.ErrNoRows) {
		return NoEntriesFoundError
	}
	return err
}