Vivek Shukla

How I use Postgres DB as a Key Value Store in Rust with SQLx


For smaller (or personal) applications, or when I'm starting a new app, I like to keep my program's dependencies to a minimum. Since I'm going to need a database anyway, I'd rather use Postgres as my key value store than add Redis as one more component to my tech stack.

Thinking about scaling prematurely, when you don't even know whether your app will be successful enough to need all those bells and whistles, is a sure way to push back the deadline and procrastinate.

🔗Overview

We are going to create a trait called PgStore, which provides default implementations of the methods to set, get, pop and delete values.

Then you can implement this trait for your own types. Your types must derive serde’s Serialize and Deserialize for this to work.

🔗Dependencies

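You need the following dependencies in your Cargo.toml. The code also relies on tokio (for the async runtime, the static pool and the cleanup task), so it is listed as well:

[dependencies]
chrono = { version = "0.4", features = ["serde"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0" }
sqlx = { version = "0.8", features = [
    "runtime-tokio",
    "tls-rustls-ring-native-roots",
    "migrate",
    "postgres",
    "macros",
    "chrono",
] }
tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "time"] }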

🔗Creating static PgPool connection

I don't like passing a DB connection around across functions, threads or tasks. Instead, I use tokio's OnceCell: a thread-safe cell that can be written to only once. Unlike std's OnceLock, its get_or_init takes a closure that returns a future, so we can await the database connection inside the initializer.

use sqlx::PgPool;
use tokio::sync::OnceCell;

static PG_POOL: OnceCell<PgPool> = OnceCell::const_new();

pub async fn pg_pool() -> &'static PgPool {
    PG_POOL
        .get_or_init(|| async {
            // Runs only once: the first caller connects, later callers reuse the pool.
            let url = std::env::var("DATABASE_URL").expect("DATABASE_URL is not set");
            PgPool::connect(&url)
                .await
                .expect("❌ PgPool connection failed")
        })
        .await
}

Now we can call pg_pool().await wherever we need the pool. The snippets below call it as db::pg_pool(), which assumes this code lives in a module named db.
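The key property here is that the initializer runs exactly once, no matter how many callers race for the pool. Below is a std-only sketch of the same pattern using std::sync::OnceLock. It is synchronous, so it cannot replace the async tokio version. FakePool, fake_pool and INIT_CALLS are hypothetical names standing in for PgPool and pg_pool().

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

// Hypothetical stand-in for PgPool: it only remembers the URL it was built from.
#[derive(Debug)]
pub struct FakePool {
    pub url: String,
}

// Counts how many times the initializer actually ran.
pub static INIT_CALLS: AtomicUsize = AtomicUsize::new(0);

static POOL: OnceLock<FakePool> = OnceLock::new();

// Same shape as pg_pool(): returns a &'static reference that is built on first use only.
pub fn fake_pool() -> &'static FakePool {
    POOL.get_or_init(|| {
        INIT_CALLS.fetch_add(1, Ordering::SeqCst);
        // A real pool would read DATABASE_URL here. A fixed string keeps the sketch self-contained.
        FakePool {
            url: "postgres://localhost/app".to_string(),
        }
    })
}
```

Every call, from any thread, gets the same reference, and INIT_CALLS stays at 1.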

🔗DB Migration

We are creating a new table, kv_store, in our Postgres DB. Note the UNLOGGED keyword. It tells Postgres not to write WAL (write-ahead log) for the kv_store table, which improves write performance at the expense of durability. An unlogged table is truncated automatically after a crash or unclean shutdown, and it is not replicated to standby servers. That is acceptable for cache-like data.

CREATE UNLOGGED TABLE "kv_store" (
    "key" VARCHAR(2048) PRIMARY KEY,
    "value" TEXT NOT NULL,
    "expires" TIMESTAMPTZ NOT NULL
);

🔗PgStore Trait

Next we create the PgStore trait, which has methods to store, retrieve and delete values in the kv_store table.

use chrono::{DateTime, Utc};
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::time::Duration;

#[allow(async_fn_in_trait)]
pub trait PgStore: Serialize + DeserializeOwned {
    const EXPIRE_IN: usize; // seconds

    fn key_format(key: String) -> String;

    fn get_expire_time() -> DateTime<Utc> {
        Utc::now() + Duration::from_secs(Self::EXPIRE_IN as u64)
    }

    async fn set_ex(&self, key: String) -> Result<(), String> {
        let value =
            serde_json::to_string(&self).map_err(|_| "Failed to serialize value.".to_string())?;

        sqlx::query!(
            r#"INSERT INTO kv_store ("key", "value", expires) VALUES ($1, $2, $3)
            ON CONFLICT ("key") DO UPDATE SET "value" = $2, expires = $3"#,
            Self::key_format(key),
            value,
            Self::get_expire_time(),
        )
        .execute(db::pg_pool().await)
        .await
        .map_err(|_| "Failed to set value in kv_store".to_string())?;

        Ok(())
    }

    // fetches the value if the key exists and hasn't expired yet
    async fn get(key: String) -> Result<Self, String> {
        let record = sqlx::query!(
            r#"SELECT "value" FROM kv_store WHERE "key" = $1 AND expires > NOW()"#,
            Self::key_format(key)
        )
        .fetch_one(db::pg_pool().await)
        .await
        .map_err(|_| "Failed to get value from kv_store".to_string())?;

        serde_json::from_str(&record.value).map_err(|_| "Failed to deserialize value".to_string())
    }

    // fetches the value while extending its expire time (expired keys are skipped)
    async fn get_ex(key: String) -> Result<Self, String> {
        let record = sqlx::query!(
            r#"UPDATE kv_store SET expires = $2 WHERE "key" = $1 AND expires > NOW() RETURNING "value""#,
            Self::key_format(key),
            Self::get_expire_time(),
        )
        .fetch_one(db::pg_pool().await)
        .await
        .map_err(|_| "Failed to get value from kv_store".to_string())?;

        serde_json::from_str(&record.value).map_err(|_| "Failed to deserialize value".to_string())
    }

    // fetches the value while deleting the key (a pop); expired keys are left to the cleanup task
    async fn get_del(key: String) -> Result<Self, String> {
        let record = sqlx::query!(
            r#"DELETE FROM kv_store WHERE "key" = $1 AND expires > NOW() RETURNING "value""#,
            Self::key_format(key)
        )
        .fetch_one(db::pg_pool().await)
        .await
        .map_err(|_| "Failed to get value from kv_store".to_string())?;

        serde_json::from_str(&record.value).map_err(|_| "Failed to deserialize value".to_string())
    }

    async fn del(key: String) -> Result<(), String> {
        sqlx::query!(
            r#"DELETE FROM kv_store WHERE "key" = $1"#,
            Self::key_format(key)
        )
        .execute(db::pg_pool().await)
        .await
        .map_err(|_| "Failed to delete value".to_string())?;

        Ok(())
    }
}

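By declaring the trait as PgStore: Serialize + DeserializeOwned, we require every type that implements PgStore to also implement Serialize and DeserializeOwned. This is what lets the default methods call serde_json::to_string and serde_json::from_str on Self. Any type that derives Deserialize without borrowing from its input automatically satisfies DeserializeOwned.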
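The same supertrait mechanism can be seen without serde. In this std-only sketch, Display and FromStr stand in for Serialize and DeserializeOwned. TextStore and Score are hypothetical names, not part of the post's code. The default methods can call to_string() and parse() only because the trait declares those supertraits.

```rust
use std::fmt::Display;
use std::str::FromStr;

// Hypothetical analogue of PgStore: Display/FromStr play the roles of Serialize/DeserializeOwned.
pub trait TextStore: Display + FromStr {
    // Required, like key_format in PgStore.
    fn key_format(key: &str) -> String;

    // Default method: may call to_string() because Display is a supertrait.
    fn encode(&self) -> String {
        self.to_string()
    }

    // Default method: may call parse() because FromStr is a supertrait.
    fn decode(raw: &str) -> Result<Self, String> {
        raw.parse::<Self>()
            .map_err(|_| "Failed to deserialize value".to_string())
    }
}

#[derive(Debug, PartialEq)]
pub struct Score(pub u32);

impl Display for Score {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl FromStr for Score {
    type Err = std::num::ParseIntError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse().map(Score)
    }
}

// Only the required item has to be written; encode/decode come for free.
impl TextStore for Score {
    fn key_format(key: &str) -> String {
        format!("score-{}", key)
    }
}
```

As with PgStore, the implementing type only has to supply the required item.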

Here is a brief explanation of PgStore's items:

EXPIRE_IN: The time to live, in seconds, applied whenever a value is stored (set_ex) or refreshed (get_ex). Every implementing type must define it.

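key_format: Must be implemented. It maps the key you pass in to the key actually stored in the table, for example by adding a type-specific prefix so that keys of different types don't collide.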

set_ex: Stores the value under the given key and sets its expire time, overwriting any existing value for that key.

get: Fetches the value for a given key. It returns an error if the key does not exist.

get_ex: Fetches the value while resetting the expire time.

get_del: Fetches the value and deletes the key in a single query, much like a pop.

del: Deletes the key. It doesn't return an error if the key doesn't exist.
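To make these semantics concrete without a database, here is a hypothetical in-memory model of the kv_store table. MemStore and its explicit now argument (seconds) are assumptions for illustration, not part of the post. In the Postgres version, whether an expired row is still readable before pgstore_cleanup deletes it depends on whether the query filters on expires. This model simply treats expired entries as missing.

```rust
use std::collections::HashMap;

// Hypothetical in-memory model of kv_store: key -> (value, expires_at in seconds).
pub struct MemStore {
    rows: HashMap<String, (String, u64)>,
    expire_in: u64,
}

impl MemStore {
    pub fn new(expire_in: u64) -> Self {
        MemStore { rows: HashMap::new(), expire_in }
    }

    // Like set_ex: insert, or overwrite value and expiry (ON CONFLICT ... DO UPDATE).
    pub fn set_ex(&mut self, key: &str, value: &str, now: u64) {
        self.rows
            .insert(key.to_string(), (value.to_string(), now + self.expire_in));
    }

    // Like get: error if the key is missing (or, in this model, expired).
    pub fn get(&self, key: &str, now: u64) -> Result<String, String> {
        match self.rows.get(key) {
            Some((value, expires)) if *expires > now => Ok(value.clone()),
            _ => Err("Failed to get value from kv_store".to_string()),
        }
    }

    // Like get_ex: return the value and push its expiry forward.
    pub fn get_ex(&mut self, key: &str, now: u64) -> Result<String, String> {
        let expire_in = self.expire_in;
        match self.rows.get_mut(key) {
            Some((value, expires)) if *expires > now => {
                *expires = now + expire_in;
                Ok(value.clone())
            }
            _ => Err("Failed to get value from kv_store".to_string()),
        }
    }

    // Like get_del: remove the key and return its value (a pop).
    pub fn get_del(&mut self, key: &str, now: u64) -> Result<String, String> {
        match self.rows.remove(key) {
            Some((value, expires)) if expires > now => Ok(value),
            _ => Err("Failed to get value from kv_store".to_string()),
        }
    }

    // Like del: no error if the key is absent.
    pub fn del(&mut self, key: &str) {
        self.rows.remove(key);
    }

    // Like the cleanup query `DELETE FROM kv_store WHERE expires < NOW()`.
    pub fn cleanup(&mut self, now: u64) {
        self.rows.retain(|_, (_, expires)| *expires >= now);
    }

    pub fn len(&self) -> usize {
        self.rows.len()
    }
}
```

For example, with expire_in = 60, a value set at t = 10 is readable until t = 70, and a get_ex at t = 65 extends that to t = 125.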

🔗Cleaning up expired keys

Finally, we need a helper function, pgstore_cleanup, which deletes expired keys every 60 seconds. Spawn it as a background task so it keeps running for as long as your program does.

pub async fn pgstore_cleanup() {
    loop {
        // tracing::info!("🟡 PgStore cleanup is running...");
        sqlx::query!(r#"DELETE FROM kv_store WHERE expires < NOW()"#)
            .execute(db::pg_pool().await)
            .await
            .ok();
        tokio::time::sleep(Duration::from_secs(60)).await;
    }
}
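pgstore_cleanup loops forever, which suits a task meant to live as long as the process. If you also want to stop it cleanly (in tests, say), one option is to replace the plain sleep with a wait that can be interrupted. Below is a std-only sketch using a thread and a channel, where spawn_cleanup and Expiries are hypothetical names. With tokio, the analogous approach would be tokio::select! over the sleep and a shutdown signal, which this sketch does not show.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, RecvTimeoutError};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};

// Hypothetical shared store: key -> expiry instant.
pub type Expiries = Arc<Mutex<HashMap<String, Instant>>>;

// Spawns a background cleaner that removes expired keys every `every`.
// It exits as soon as the returned Sender is dropped or sent a message.
pub fn spawn_cleanup(store: Expiries, every: Duration) -> (mpsc::Sender<()>, JoinHandle<()>) {
    let (stop_tx, stop_rx) = mpsc::channel::<()>();
    let handle = thread::spawn(move || loop {
        {
            let now = Instant::now();
            // Same rule as `DELETE FROM kv_store WHERE expires < NOW()`.
            store.lock().unwrap().retain(|_, expires| *expires >= now);
        } // lock released here, before waiting
        // Wait for `every`, but wake early if asked to stop.
        match stop_rx.recv_timeout(every) {
            Err(RecvTimeoutError::Timeout) => continue,
            _ => break, // stop message received or sender dropped
        }
    });
    (stop_tx, handle)
}
```

Dropping the sender ends the loop on its next wait, so the thread can be joined instead of left running.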

🔗Usage

#[derive(serde::Serialize, serde::Deserialize)]
pub struct Game {
    pub id: usize,
    pub name: String,
    pub genre: String,
}

impl PgStore for Game {
    const EXPIRE_IN: usize = 3600;

    fn key_format(key: String) -> String {
        format!("game-{}", key)
    }
}

#[tokio::main]
async fn main() {
    // make sure to run this to delete expired keys regularly
    tokio::spawn(pgstore_cleanup());

    let game = Game {
        id: 1,
        name: "Skyrim".to_string(),
        genre: "rpg".to_string(),
    };

    // storing value in kv_store
    // it will be stored with key: "game-1"
    let _ = game.set_ex(game.id.to_string()).await;

    // fetching value
    if let Ok(v) = Game::get(1.to_string()).await {
        println!("fetched game with name: {}", v.name);
    }

    // deleting key
    let _ = Game::del(1.to_string()).await;
}

Similarly, you can do this with your own types, as long as they derive serde's Serialize and Deserialize.