package main import ( "crypto/tls" "crypto/x509" "database/sql" "errors" "fmt" "github.com/go-sql-driver/mysql" "log" "os" "strconv" "time" ) type LDB struct { db *sql.DB } var NoEntriesFoundError error = errors.New("no entries found") func createTLSConf(skipVerify bool) tls.Config { rootCertPool := x509.NewCertPool() pem, err := os.ReadFile(os.Getenv("DB_CA_CERT_PATH")) if err != nil { log.Fatal(err) } if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { log.Fatal("Failed to append PEM.") } clientCert := make([]tls.Certificate, 0, 1) keyPair, err := tls.LoadX509KeyPair(os.Getenv("DB_CERT_PATH"), os.Getenv("DB_KEY_PATH")) if err != nil { log.Fatal(err) } clientCert = append(clientCert, keyPair) return tls.Config{ RootCAs: rootCertPool, Certificates: clientCert, InsecureSkipVerify: skipVerify, } } func connect() *LDB { var db *sql.DB caCert := os.Getenv("DB_CA_CERT_PATH") if len(caCert) > 0 { skipVerify, err := strconv.ParseBool(os.Getenv("DB_SKIP_VERIFY")) if err != nil { log.Fatal(err) } tlsConf := createTLSConf(skipVerify) err = mysql.RegisterTLSConfig("custom", &tlsConf) if err != nil { log.Fatalf("Error %s on RegisterTLSConfig\n", err) return nil } connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?tls=custom&parseTime=true&loc=Local", os.Getenv("DB_USER"), os.Getenv("DB_PASS"), os.Getenv("DB_HOST"), os.Getenv("DB_PORT"), os.Getenv("DB_SCHEMA"), ) db, err = sql.Open("mysql", connStr) if err != nil { log.Fatalf("Error connecting to db: %s", err) return nil } } else { connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true&loc=Local", os.Getenv("DB_USER"), os.Getenv("DB_PASS"), os.Getenv("DB_HOST"), os.Getenv("DB_PORT"), os.Getenv("DB_SCHEMA"), ) var err error db, err = sql.Open("mysql", connStr) if err != nil { log.Fatal(err) } } return &LDB{db} } func (ldb *LDB) queryUser(userName string) (u *User, newError error) { db := ldb.db queryStr := ` SELECT users.user_name, users.contact_name, users.secret_codes, users.current_token, users.contact_number, users.email_address, addresses.street1, addresses.street2, addresses.city, addresses.state, addresses.postal_code, addresses.country, addresses.created_at, addresses.updated_at FROM users INNER JOIN addresses ON users.mailing_address=addresses.id WHERE users.user_name=?; ` stmt, err := db.Prepare(queryStr) if err != nil { return nil, err } res, err := stmt.Query(userName) if err != nil { return nil, err } defer res.Close() if !res.Next() { return nil, NoEntriesFoundError } user := new(User) err = res.Scan(&user.UserName, &user.ContactName, &user.SecretCodes, &user.CurrentToken, &user.ContactNumber, &user.EmailAddress, &user.MailingAddress.Street1, &user.MailingAddress.Street2, &user.MailingAddress.City, &user.MailingAddress.State, &user.MailingAddress.PostalCode, &user.MailingAddress.Country, &user.MailingAddress.CreatedAt, &user.MailingAddress.UpdatedAt, ) if err != nil { return nil, err } return user, nil } func (ldb *LDB) updateToken(userName string, token string) error { db := ldb.db updateStr := ` UPDATE users SET current_token=?,token_creation=? WHERE user_name=?; ` stmt, err := db.Prepare(updateStr) if err != nil { return err } _, err = stmt.Exec(token, time.Now(), userName) if err != nil { return err } return nil }