LuggageTracker/db.go
2025-07-10 09:03:57 -04:00

146 lines
3.4 KiB
Go

package main
import (
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"github.com/go-sql-driver/mysql"
"log"
"os"
"strconv"
"time"
)
// LDB wraps the *sql.DB handle that the query helpers defined on it
// (queryUser, updateToken) run their statements against.
type LDB struct {
// db is the database handle; connect fills it with the result of sql.Open.
db *sql.DB
}
// NoEntriesFoundError is the sentinel error returned by queryUser when the
// lookup produces no row.
var NoEntriesFoundError = errors.New("no entries found")
// createTLSConf assembles the TLS client configuration for the database
// connection from files named by environment variables:
//
//   - DB_CA_CERT_PATH: PEM file whose certificates become the root CA pool.
//   - DB_CERT_PATH and DB_KEY_PATH: PEM certificate/key pair presented as the
//     client certificate.
//
// skipVerify is copied into InsecureSkipVerify unchanged. Any failure to read
// or parse the files terminates the process through log.Fatal, so the
// function only returns on success.
func createTLSConf(skipVerify bool) tls.Config {
	caPEM, err := os.ReadFile(os.Getenv("DB_CA_CERT_PATH"))
	if err != nil {
		log.Fatal(err)
	}

	caPool := x509.NewCertPool()
	if !caPool.AppendCertsFromPEM(caPEM) {
		log.Fatal("Failed to append PEM.")
	}

	certPath, keyPath := os.Getenv("DB_CERT_PATH"), os.Getenv("DB_KEY_PATH")
	clientPair, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		log.Fatal(err)
	}

	return tls.Config{
		RootCAs:            caPool,
		Certificates:       []tls.Certificate{clientPair},
		InsecureSkipVerify: skipVerify,
	}
}
// connect builds the MySQL handle used by the application from environment
// configuration and returns it wrapped in an *LDB.
//
// Environment variables read:
//
//   - DB_USER, DB_PASS, DB_HOST, DB_PORT, DB_SCHEMA: the pieces of the DSN.
//   - DB_CA_CERT_PATH: when non-empty, TLS is enabled. createTLSConf builds the
//     configuration, which is registered with the driver as "custom" and
//     selected in the DSN with tls=custom.
//   - DB_SKIP_VERIFY: only consulted when TLS is enabled. Parsed with
//     strconv.ParseBool; unset or empty means false (certificates are
//     verified).
//
// Every configuration error ends the process through log.Fatal, so the
// returned pointer is never nil. Per the database/sql documentation sql.Open
// may only validate its arguments, so a successful return does not by itself
// prove the server is reachable.
func connect() *LDB {
	// Options shared by the TLS and plain-text DSNs.
	params := "parseTime=true&loc=Local"

	if os.Getenv("DB_CA_CERT_PATH") != "" {
		// Default to verifying the server certificate when the flag is absent
		// instead of aborting on ParseBool("").
		skipVerify := false
		if raw := os.Getenv("DB_SKIP_VERIFY"); raw != "" {
			parsed, err := strconv.ParseBool(raw)
			if err != nil {
				log.Fatal(err)
			}
			skipVerify = parsed
		}

		tlsConf := createTLSConf(skipVerify)
		if err := mysql.RegisterTLSConfig("custom", &tlsConf); err != nil {
			log.Fatalf("Error %s on RegisterTLSConfig\n", err)
		}
		params = "tls=custom&" + params
	}

	connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s",
		os.Getenv("DB_USER"),
		os.Getenv("DB_PASS"),
		os.Getenv("DB_HOST"),
		os.Getenv("DB_PORT"),
		os.Getenv("DB_SCHEMA"),
		params,
	)

	db, err := sql.Open("mysql", connStr)
	if err != nil {
		log.Fatalf("Error connecting to db: %s", err)
	}
	return &LDB{db: db}
}
// queryUser loads the user whose users.user_name equals userName, together
// with the mailing address joined from the addresses table.
//
// It returns NoEntriesFoundError when the query yields no row, and otherwise
// passes through any error reported by the database or by Scan. On error the
// returned *User is nil.
//
// QueryRow is used instead of an explicit Prepare/Query pair so that no
// prepared statement is left open after the call, and so that an error hit
// while fetching the row is reported as such rather than as "no entries".
func (ldb *LDB) queryUser(userName string) (u *User, newError error) {
	const queryStr = `
SELECT users.user_name, users.contact_name, users.secret_codes, users.current_token, users.contact_number, users.email_address,
addresses.street1, addresses.street2, addresses.city, addresses.state, addresses.postal_code,
addresses.country, addresses.created_at, addresses.updated_at
FROM users
INNER JOIN addresses ON users.mailing_address=addresses.id
WHERE users.user_name=?;
`
	user := new(User)
	// Destination order must match the SELECT column order above.
	err := ldb.db.QueryRow(queryStr, userName).Scan(
		&user.UserName, &user.ContactName, &user.SecretCodes, &user.CurrentToken, &user.ContactNumber, &user.EmailAddress,
		&user.MailingAddress.Street1, &user.MailingAddress.Street2, &user.MailingAddress.City,
		&user.MailingAddress.State, &user.MailingAddress.PostalCode, &user.MailingAddress.Country,
		&user.MailingAddress.CreatedAt, &user.MailingAddress.UpdatedAt,
	)
	if errors.Is(err, sql.ErrNoRows) {
		// Keep the package's own sentinel so existing callers still match it.
		return nil, NoEntriesFoundError
	}
	if err != nil {
		return nil, err
	}
	return user, nil
}
// updateToken stores token as current_token for the user named userName and
// sets token_creation to the current time (time.Now()).
//
// It returns any error reported by the database. The Exec result is
// discarded, so the number of rows affected is not checked.
//
// Exec is called on the handle directly rather than through an explicit
// Prepare, so no prepared statement is left open after the call.
func (ldb *LDB) updateToken(userName string, token string) error {
	const updateStr = `
UPDATE users
SET current_token=?,token_creation=?
WHERE user_name=?;
`
	_, err := ldb.db.Exec(updateStr, token, time.Now(), userName)
	return err
}