mas_storage_pg/user/mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7//! A module containing the PostgreSQL implementation of the user-related
8//! repositories
9
10use async_trait::async_trait;
11use mas_data_model::{Clock, User};
12use mas_storage::user::{UserFilter, UserRepository};
13use rand::RngCore;
14use sea_query::{Expr, PostgresQueryBuilder, Query};
15use sea_query_binder::SqlxBinder;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{
21    DatabaseError,
22    filter::{Filter, StatementExt},
23    iden::Users,
24    pagination::QueryBuilderExt,
25    tracing::ExecuteExt,
26};
27
28mod email;
29mod password;
30mod recovery;
31mod registration;
32mod registration_token;
33mod session;
34mod terms;
35
36#[cfg(test)]
37mod tests;
38
39pub use self::{
40    email::PgUserEmailRepository, password::PgUserPasswordRepository,
41    recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
42    registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
43    terms::PgUserTermsRepository,
44};
45
46/// An implementation of [`UserRepository`] for a PostgreSQL connection
47pub struct PgUserRepository<'c> {
48    conn: &'c mut PgConnection,
49}
50
51impl<'c> PgUserRepository<'c> {
52    /// Create a new [`PgUserRepository`] from an active PostgreSQL connection
53    pub fn new(conn: &'c mut PgConnection) -> Self {
54        Self { conn }
55    }
56}
57
58mod priv_ {
59    // The enum_def macro generates a public enum, which we don't want, because it
60    // triggers the missing docs warning
61    #![allow(missing_docs)]
62
63    use chrono::{DateTime, Utc};
64    use sea_query::enum_def;
65    use uuid::Uuid;
66
67    #[derive(Debug, Clone, sqlx::FromRow)]
68    #[enum_def]
69    pub(super) struct UserLookup {
70        pub(super) user_id: Uuid,
71        pub(super) username: String,
72        pub(super) created_at: DateTime<Utc>,
73        pub(super) locked_at: Option<DateTime<Utc>>,
74        pub(super) deactivated_at: Option<DateTime<Utc>>,
75        pub(super) can_request_admin: bool,
76        pub(super) is_guest: bool,
77    }
78}
79
80use priv_::{UserLookup, UserLookupIden};
81
82impl From<UserLookup> for User {
83    fn from(value: UserLookup) -> Self {
84        let id = value.user_id.into();
85        Self {
86            id,
87            username: value.username,
88            sub: id.to_string(),
89            created_at: value.created_at,
90            locked_at: value.locked_at,
91            deactivated_at: value.deactivated_at,
92            can_request_admin: value.can_request_admin,
93            is_guest: value.is_guest,
94        }
95    }
96}
97
98impl Filter for UserFilter<'_> {
99    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
100        sea_query::Condition::all()
101            .add_option(self.state().map(|state| {
102                match state {
103                    mas_storage::user::UserState::Deactivated => {
104                        Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
105                    }
106                    mas_storage::user::UserState::Locked => {
107                        Expr::col((Users::Table, Users::LockedAt)).is_not_null()
108                    }
109                    mas_storage::user::UserState::Active => {
110                        Expr::col((Users::Table, Users::LockedAt))
111                            .is_null()
112                            .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
113                    }
114                }
115            }))
116            .add_option(self.can_request_admin().map(|can_request_admin| {
117                Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
118            }))
119            .add_option(
120                self.is_guest()
121                    .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
122            )
123    }
124}
125
126#[async_trait]
127impl UserRepository for PgUserRepository<'_> {
128    type Error = DatabaseError;
129
130    #[tracing::instrument(
131        name = "db.user.lookup",
132        skip_all,
133        fields(
134            db.query.text,
135            user.id = %id,
136        ),
137        err,
138    )]
139    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
140        let res = sqlx::query_as!(
141            UserLookup,
142            r#"
143                SELECT user_id
144                     , username
145                     , created_at
146                     , locked_at
147                     , deactivated_at
148                     , can_request_admin
149                     , is_guest
150                FROM users
151                WHERE user_id = $1
152            "#,
153            Uuid::from(id),
154        )
155        .traced()
156        .fetch_optional(&mut *self.conn)
157        .await?;
158
159        let Some(res) = res else { return Ok(None) };
160
161        Ok(Some(res.into()))
162    }
163
164    #[tracing::instrument(
165        name = "db.user.find_by_username",
166        skip_all,
167        fields(
168            db.query.text,
169            user.username = username,
170        ),
171        err,
172    )]
173    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
174        // We may have multiple users with the same username, but with a different
175        // casing. In this case, we want to return the one which matches the exact
176        // casing
177        let res = sqlx::query_as!(
178            UserLookup,
179            r#"
180                SELECT user_id
181                     , username
182                     , created_at
183                     , locked_at
184                     , deactivated_at
185                     , can_request_admin
186                     , is_guest
187                FROM users
188                WHERE LOWER(username) = LOWER($1)
189            "#,
190            username,
191        )
192        .traced()
193        .fetch_all(&mut *self.conn)
194        .await?;
195
196        match &res[..] {
197            // Happy path: there is only one user matching the username…
198            [user] => Ok(Some(user.clone().into())),
199            // …or none.
200            [] => Ok(None),
201            list => {
202                // If there are multiple users with the same username, we want to
203                // return the one which matches the exact casing
204                if let Some(user) = list.iter().find(|user| user.username == username) {
205                    Ok(Some(user.clone().into()))
206                } else {
207                    // If none match exactly, we prefer to return nothing
208                    Ok(None)
209                }
210            }
211        }
212    }
213
214    #[tracing::instrument(
215        name = "db.user.add",
216        skip_all,
217        fields(
218            db.query.text,
219            user.username = username,
220            user.id,
221        ),
222        err,
223    )]
224    async fn add(
225        &mut self,
226        rng: &mut (dyn RngCore + Send),
227        clock: &dyn Clock,
228        username: String,
229    ) -> Result<User, Self::Error> {
230        let created_at = clock.now();
231        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
232        tracing::Span::current().record("user.id", tracing::field::display(id));
233
234        let res = sqlx::query!(
235            r#"
236                INSERT INTO users (user_id, username, created_at)
237                VALUES ($1, $2, $3)
238                ON CONFLICT (username) DO NOTHING
239            "#,
240            Uuid::from(id),
241            username,
242            created_at,
243        )
244        .traced()
245        .execute(&mut *self.conn)
246        .await?;
247
248        // If the user already exists, want to return an error but not poison the
249        // transaction
250        DatabaseError::ensure_affected_rows(&res, 1)?;
251
252        Ok(User {
253            id,
254            username,
255            sub: id.to_string(),
256            created_at,
257            locked_at: None,
258            deactivated_at: None,
259            can_request_admin: false,
260            is_guest: false,
261        })
262    }
263
264    #[tracing::instrument(
265        name = "db.user.exists",
266        skip_all,
267        fields(
268            db.query.text,
269            user.username = username,
270        ),
271        err,
272    )]
273    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
274        let exists = sqlx::query_scalar!(
275            r#"
276                SELECT EXISTS(
277                    SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
278                ) AS "exists!"
279            "#,
280            username
281        )
282        .traced()
283        .fetch_one(&mut *self.conn)
284        .await?;
285
286        Ok(exists)
287    }
288
289    #[tracing::instrument(
290        name = "db.user.lock",
291        skip_all,
292        fields(
293            db.query.text,
294            %user.id,
295        ),
296        err,
297    )]
298    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
299        if user.locked_at.is_some() {
300            return Ok(user);
301        }
302
303        let locked_at = clock.now();
304        let res = sqlx::query!(
305            r#"
306                UPDATE users
307                SET locked_at = $1
308                WHERE user_id = $2
309            "#,
310            locked_at,
311            Uuid::from(user.id),
312        )
313        .traced()
314        .execute(&mut *self.conn)
315        .await?;
316
317        DatabaseError::ensure_affected_rows(&res, 1)?;
318
319        user.locked_at = Some(locked_at);
320
321        Ok(user)
322    }
323
324    #[tracing::instrument(
325        name = "db.user.unlock",
326        skip_all,
327        fields(
328            db.query.text,
329            %user.id,
330        ),
331        err,
332    )]
333    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
334        if user.locked_at.is_none() {
335            return Ok(user);
336        }
337
338        let res = sqlx::query!(
339            r#"
340                UPDATE users
341                SET locked_at = NULL
342                WHERE user_id = $1
343            "#,
344            Uuid::from(user.id),
345        )
346        .traced()
347        .execute(&mut *self.conn)
348        .await?;
349
350        DatabaseError::ensure_affected_rows(&res, 1)?;
351
352        user.locked_at = None;
353
354        Ok(user)
355    }
356
357    #[tracing::instrument(
358        name = "db.user.deactivate",
359        skip_all,
360        fields(
361            db.query.text,
362            %user.id,
363        ),
364        err,
365    )]
366    async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
367        if user.deactivated_at.is_some() {
368            return Ok(user);
369        }
370
371        let deactivated_at = clock.now();
372        let res = sqlx::query!(
373            r#"
374                UPDATE users
375                SET deactivated_at = $2
376                WHERE user_id = $1
377                  AND deactivated_at IS NULL
378            "#,
379            Uuid::from(user.id),
380            deactivated_at,
381        )
382        .traced()
383        .execute(&mut *self.conn)
384        .await?;
385
386        DatabaseError::ensure_affected_rows(&res, 1)?;
387
388        user.deactivated_at = Some(deactivated_at);
389
390        Ok(user)
391    }
392
393    #[tracing::instrument(
394        name = "db.user.reactivate",
395        skip_all,
396        fields(
397            db.query.text,
398            %user.id,
399        ),
400        err,
401    )]
402    async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
403        if user.deactivated_at.is_none() {
404            return Ok(user);
405        }
406
407        let res = sqlx::query!(
408            r#"
409                UPDATE users
410                SET deactivated_at = NULL
411                WHERE user_id = $1
412            "#,
413            Uuid::from(user.id),
414        )
415        .traced()
416        .execute(&mut *self.conn)
417        .await?;
418
419        DatabaseError::ensure_affected_rows(&res, 1)?;
420
421        user.deactivated_at = None;
422
423        Ok(user)
424    }
425
426    #[tracing::instrument(
427        name = "db.user.set_can_request_admin",
428        skip_all,
429        fields(
430            db.query.text,
431            %user.id,
432            user.can_request_admin = can_request_admin,
433        ),
434        err,
435    )]
436    async fn set_can_request_admin(
437        &mut self,
438        mut user: User,
439        can_request_admin: bool,
440    ) -> Result<User, Self::Error> {
441        let res = sqlx::query!(
442            r#"
443                UPDATE users
444                SET can_request_admin = $2
445                WHERE user_id = $1
446            "#,
447            Uuid::from(user.id),
448            can_request_admin,
449        )
450        .traced()
451        .execute(&mut *self.conn)
452        .await?;
453
454        DatabaseError::ensure_affected_rows(&res, 1)?;
455
456        user.can_request_admin = can_request_admin;
457
458        Ok(user)
459    }
460
461    #[tracing::instrument(
462        name = "db.user.list",
463        skip_all,
464        fields(
465            db.query.text,
466        ),
467        err,
468    )]
469    async fn list(
470        &mut self,
471        filter: UserFilter<'_>,
472        pagination: mas_storage::Pagination,
473    ) -> Result<mas_storage::Page<User>, Self::Error> {
474        let (sql, arguments) = Query::select()
475            .expr_as(
476                Expr::col((Users::Table, Users::UserId)),
477                UserLookupIden::UserId,
478            )
479            .expr_as(
480                Expr::col((Users::Table, Users::Username)),
481                UserLookupIden::Username,
482            )
483            .expr_as(
484                Expr::col((Users::Table, Users::CreatedAt)),
485                UserLookupIden::CreatedAt,
486            )
487            .expr_as(
488                Expr::col((Users::Table, Users::LockedAt)),
489                UserLookupIden::LockedAt,
490            )
491            .expr_as(
492                Expr::col((Users::Table, Users::DeactivatedAt)),
493                UserLookupIden::DeactivatedAt,
494            )
495            .expr_as(
496                Expr::col((Users::Table, Users::CanRequestAdmin)),
497                UserLookupIden::CanRequestAdmin,
498            )
499            .expr_as(
500                Expr::col((Users::Table, Users::IsGuest)),
501                UserLookupIden::IsGuest,
502            )
503            .from(Users::Table)
504            .apply_filter(filter)
505            .generate_pagination((Users::Table, Users::UserId), pagination)
506            .build_sqlx(PostgresQueryBuilder);
507
508        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
509            .traced()
510            .fetch_all(&mut *self.conn)
511            .await?;
512
513        let page = pagination.process(edges).map(User::from);
514
515        Ok(page)
516    }
517
518    #[tracing::instrument(
519        name = "db.user.count",
520        skip_all,
521        fields(
522            db.query.text,
523        ),
524        err,
525    )]
526    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
527        let (sql, arguments) = Query::select()
528            .expr(Expr::col((Users::Table, Users::UserId)).count())
529            .from(Users::Table)
530            .apply_filter(filter)
531            .build_sqlx(PostgresQueryBuilder);
532
533        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
534            .traced()
535            .fetch_one(&mut *self.conn)
536            .await?;
537
538        count
539            .try_into()
540            .map_err(DatabaseError::to_invalid_operation)
541    }
542
543    #[tracing::instrument(
544        name = "db.user.acquire_lock_for_sync",
545        skip_all,
546        fields(
547            db.query.text,
548            user.id = %user.id,
549        ),
550        err,
551    )]
552    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
553        // XXX: this lock isn't stictly scoped to users, but as we don't use many
554        // postgres advisory locks, it's fine for now. Later on, we could use row-level
555        // locks to make sure we don't get into trouble
556
557        // Convert the user ID to a u128 and grab the lower 64 bits
558        // As this includes 64bit of the random part of the ULID, it should be random
559        // enough to not collide
560        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
561
562        // Use a PG advisory lock, which will be released when the transaction is
563        // committed or rolled back
564        sqlx::query!(
565            r#"
566                SELECT pg_advisory_xact_lock($1)
567            "#,
568            lock_id,
569        )
570        .traced()
571        .execute(&mut *self.conn)
572        .await?;
573
574        Ok(())
575    }
576}