url-shortener/postgres.go

123 lines
2.6 KiB
Go
Raw Permalink Normal View History

package main
import (
"database/sql"
"errors"
"fmt"
_ "github.com/lib/pq"
"log"
"os"
)
// DbConfig carries the PostgreSQL connection settings that connectDB fills
// in from the DB_* environment variables. All values are kept as raw strings,
// including the port.
type DbConfig struct {
	dbName, dbUser, dbPassword string // database name and credentials
	dbHost, dbPort             string // server address
}
// connectDB opens a PostgreSQL connection pool configured from the DB_NAME,
// DB_USER, DB_PASSWORD, DB_HOST and DB_PORT environment variables, checks it
// with a ping, and creates the urls table if it does not exist yet.
//
// Every failure is fatal: the error is logged and the process exits via
// log.Fatalf.
func connectDB() *sql.DB {
	dbConfig := &DbConfig{
		dbName:     os.Getenv("DB_NAME"),
		dbUser:     os.Getenv("DB_USER"),
		dbHost:     os.Getenv("DB_HOST"),
		dbPort:     os.Getenv("DB_PORT"),
		dbPassword: os.Getenv("DB_PASSWORD"),
	}
	// sql.Open may only validate its arguments; Ping is what actually proves
	// the server is reachable with these settings.
	db, err := sql.Open("postgres", pqConnString(dbConfig))
	if err != nil {
		log.Fatalf("Error connecting to database: %v", err)
	}
	if err := db.Ping(); err != nil {
		log.Fatalf("Failed to connect to the database: %v", err)
	}
	fmt.Println("Connected to the database")

	const createURLsTable = "CREATE TABLE IF NOT EXISTS urls (id bigint NOT NULL PRIMARY KEY , url TEXT NOT NULL, code varchar(18) NULL, createdAt TIMESTAMP DEFAULT NOW(), updatedAt TIMESTAMP DEFAULT NOW())"
	if _, err := db.Exec(createURLsTable); err != nil {
		log.Fatalf("Error creating table: %v", err)
	}
	// Because of IF NOT EXISTS the table may already have existed, so report
	// readiness rather than claiming it was just created.
	fmt.Printf("Table '%s' is ready.\n", "urls")
	return db
}

// pqConnString renders cfg as a lib/pq key/value connection string.
//
// Empty settings are left out entirely. Writing a bare "key=" would make the
// parser take the following "key=value" token as that key's value. Leaving
// the setting out lets the driver fall back to its own defaults. Every
// remaining value is single-quoted and escaped, so passwords with spaces,
// quotes or backslashes survive intact. sslmode is always "disable".
func pqConnString(cfg *DbConfig) string {
	params := [...]struct{ key, value string }{
		{"user", cfg.dbUser},
		{"password", cfg.dbPassword},
		{"dbname", cfg.dbName},
		{"host", cfg.dbHost},
		{"port", cfg.dbPort},
	}
	dsn := ""
	for _, p := range params {
		if p.value == "" {
			continue
		}
		dsn += p.key + "=" + quotePQValue(p.value) + " "
	}
	return dsn + "sslmode=disable"
}

// quotePQValue wraps v in single quotes, backslash-escaping any backslash or
// single quote inside it. Working byte-by-byte is safe for UTF-8 input
// because '\\' and '\'' are ASCII and never occur inside multi-byte runes.
func quotePQValue(v string) string {
	quoted := make([]byte, 0, len(v)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if v[i] == '\\' || v[i] == '\'' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, v[i])
	}
	return string(append(quoted, '\''))
}
// insertURL stores url under a freshly generated ID and returns the base-62
// short code derived from that ID. The code is also persisted in the row.
//
// A failed INSERT is returned to the caller, wrapped with the offending URL,
// instead of terminating the process. A single bad write therefore cannot
// take the whole service down. generateNewID itself still panics if the ID
// generator fails.
func insertURL(db *sql.DB, url string) (string, error) {
	id := generateNewID()
	code := intToBase62(id)
	if _, err := db.Exec("INSERT INTO urls(id, url, code) VALUES ($1, $2, $3)", id, url, code); err != nil {
		return "", fmt.Errorf("inserting url %q: %w", url, err)
	}
	return code, nil
}
// checkURLAlreadyExists returns the short code already stored for url.
// If the URL has not been shortened yet, it returns "" and a nil error.
// If several rows match, the first row returned is used.
//
// QueryRow closes its result set once Scan runs, so the pool connection is
// always released. The previous Query/rows.Next version returned without
// closing rows and leaked a connection per hit. A query failure is returned
// to the caller rather than exiting the process.
func checkURLAlreadyExists(db *sql.DB, url string) (string, error) {
	var code string
	err := db.QueryRow(`SELECT code FROM urls WHERE url = $1 LIMIT 1`, url).Scan(&code)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", nil
	case err != nil:
		return "", fmt.Errorf("looking up code for url %q: %w", url, err)
	}
	return code, nil
}
// errCodeNotFound is returned by getUrlByCode when no stored URL has the
// requested short code. Its message matches the string returned previously.
var errCodeNotFound = errors.New("code not found")

// getUrlByCode returns the original URL stored under the given short code.
// It returns errCodeNotFound when no row matches. If the query itself fails,
// it returns that error wrapped with the code.
//
// QueryRow closes its result set once Scan runs, so the pool connection is
// always released. The previous Query/rows.Next version returned without
// closing rows and leaked a connection per lookup. A query failure no longer
// exits the process.
func getUrlByCode(db *sql.DB, code string) (string, error) {
	var url string
	err := db.QueryRow("SELECT url FROM urls WHERE code = $1 LIMIT 1", code).Scan(&url)
	if errors.Is(err, sql.ErrNoRows) {
		return "", errCodeNotFound
	}
	if err != nil {
		return "", fmt.Errorf("looking up url for code %q: %w", code, err)
	}
	return url, nil
}
// generateNewID draws the next unique ID from the package-level sf generator.
// If the generator reports an error, the error is printed to stdout and the
// function panics with it, so callers never see a zero or invalid ID.
func generateNewID() uint64 {
	next, genErr := sf.NextID()
	if genErr == nil {
		return next
	}
	fmt.Println(genErr)
	panic(genErr)
}
// intToBase62 renders n as a big-endian string of base62 digits, using the
// package-level base62 alphabet. The most significant digit comes first, and
// zero encodes as the alphabet's first symbol.
func intToBase62(n uint64) string {
	if n == 0 {
		return string(base62[0])
	}
	// 62^11 exceeds 2^64, so 11 digits are enough for any uint64.
	var digits [11]byte
	pos := len(digits)
	for rest := n; rest != 0; rest /= 62 {
		pos--
		digits[pos] = base62[rest%62]
	}
	return string(digits[pos:])
}