mas_storage_pg/user/session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    Authentication, AuthenticationMethod, BrowserSession, Clock, Password,
13    UpstreamOAuthAuthorizationSession, User,
14};
15use mas_storage::{
16    Page, Pagination,
17    user::{BrowserSessionFilter, BrowserSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27    DatabaseError, DatabaseInconsistencyError,
28    filter::StatementExt,
29    iden::{UpstreamOAuthAuthorizationSessions, UserSessionAuthentications, UserSessions, Users},
30    pagination::QueryBuilderExt,
31    tracing::ExecuteExt,
32};
33
/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL
/// connection
///
/// The repository holds a mutable borrow of the connection for its whole
/// lifetime, and every query it issues is executed on that same connection.
pub struct PgBrowserSessionRepository<'c> {
    /// The borrowed connection all queries are executed on.
    conn: &'c mut PgConnection,
}
39
40impl<'c> PgBrowserSessionRepository<'c> {
41    /// Create a new [`PgBrowserSessionRepository`] from an active PostgreSQL
42    /// connection
43    pub fn new(conn: &'c mut PgConnection) -> Self {
44        Self { conn }
45    }
46}
47
/// Flat row produced by the browser-session queries: the `user_sessions`
/// columns joined with the owning `users` row.
///
/// The field names are the column aliases used by the `lookup` SQL and the
/// `SessionLookupIden` identifiers (generated by `enum_def`) used by `list`,
/// so they must stay in sync with both.
// Every field shares the `user` prefix, which clippy's `struct_field_names`
// lint flags; the prefixes are kept so the names match the SQL aliases.
#[allow(clippy::struct_field_names)]
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
    // Columns coming from `user_sessions`
    user_session_id: Uuid,
    user_session_created_at: DateTime<Utc>,
    user_session_finished_at: Option<DateTime<Utc>>,
    user_session_user_agent: Option<String>,
    user_session_last_active_at: Option<DateTime<Utc>>,
    user_session_last_active_ip: Option<IpAddr>,
    // Columns coming from the joined `users` row
    user_id: Uuid,
    user_username: String,
    user_created_at: DateTime<Utc>,
    user_locked_at: Option<DateTime<Utc>>,
    user_deactivated_at: Option<DateTime<Utc>>,
    user_can_request_admin: bool,
    user_is_guest: bool,
}
66
67impl TryFrom<SessionLookup> for BrowserSession {
68    type Error = DatabaseInconsistencyError;
69
70    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
71        let id = Ulid::from(value.user_id);
72        let user = User {
73            id,
74            username: value.user_username,
75            sub: id.to_string(),
76            created_at: value.user_created_at,
77            locked_at: value.user_locked_at,
78            deactivated_at: value.user_deactivated_at,
79            can_request_admin: value.user_can_request_admin,
80            is_guest: value.user_is_guest,
81        };
82
83        Ok(BrowserSession {
84            id: value.user_session_id.into(),
85            user,
86            created_at: value.user_session_created_at,
87            finished_at: value.user_session_finished_at,
88            user_agent: value.user_session_user_agent,
89            last_active_at: value.user_session_last_active_at,
90            last_active_ip: value.user_session_last_active_ip,
91        })
92    }
93}
94
/// Raw `user_session_authentications` row, as selected by
/// `get_last_authentication`.
///
/// At most one of `user_password_id` and
/// `upstream_oauth_authorization_session_id` is expected to be set. The
/// conversion into [`Authentication`] rejects rows where both are set.
struct AuthenticationLookup {
    user_session_authentication_id: Uuid,
    created_at: DateTime<Utc>,
    // Set by `authenticate_with_password`
    user_password_id: Option<Uuid>,
    // Set by `authenticate_with_upstream`
    upstream_oauth_authorization_session_id: Option<Uuid>,
}
101
102impl TryFrom<AuthenticationLookup> for Authentication {
103    type Error = DatabaseInconsistencyError;
104
105    fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
106        let id = Ulid::from(value.user_session_authentication_id);
107        let authentication_method = match (
108            value.user_password_id.map(Into::into),
109            value
110                .upstream_oauth_authorization_session_id
111                .map(Into::into),
112        ) {
113            (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
114            (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
115                upstream_oauth2_session_id,
116            },
117            (None, None) => AuthenticationMethod::Unknown,
118            _ => {
119                return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
120            }
121        };
122
123        Ok(Authentication {
124            id,
125            created_at: value.created_at,
126            authentication_method,
127        })
128    }
129}
130
131impl crate::filter::Filter for BrowserSessionFilter<'_> {
132    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
133        sea_query::Condition::all()
134            .add_option(self.user().map(|user| {
135                Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
136            }))
137            .add_option(self.state().map(|state| {
138                if state.is_active() {
139                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
140                } else {
141                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
142                }
143            }))
144            .add_option(self.last_active_after().map(|last_active_after| {
145                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
146            }))
147            .add_option(self.last_active_before().map(|last_active_before| {
148                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
149            }))
150            .add_option(self.authenticated_by_upstream_sessions().map(|filter| {
151                // For filtering by upstream sessions, we need to hop over the
152                // `user_session_authentications` table
153                let join_expr = Expr::col((
154                    UserSessionAuthentications::Table,
155                    UserSessionAuthentications::UpstreamOAuthAuthorizationSessionId,
156                ))
157                .eq(Expr::col((
158                    UpstreamOAuthAuthorizationSessions::Table,
159                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
160                )));
161
162                Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
163                    Query::select()
164                        .expr(Expr::col((
165                            UserSessionAuthentications::Table,
166                            UserSessionAuthentications::UserSessionId,
167                        )))
168                        .from(UserSessionAuthentications::Table)
169                        .inner_join(UpstreamOAuthAuthorizationSessions::Table, join_expr)
170                        .apply_filter(filter)
171                        .take(),
172                )
173            }))
174    }
175}
176
177#[async_trait]
178impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
179    type Error = DatabaseError;
180
181    #[tracing::instrument(
182        name = "db.browser_session.lookup",
183        skip_all,
184        fields(
185            db.query.text,
186            user_session.id = %id,
187        ),
188        err,
189    )]
190    async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
191        let res = sqlx::query_as!(
192            SessionLookup,
193            r#"
194                SELECT s.user_session_id
195                     , s.created_at            AS "user_session_created_at"
196                     , s.finished_at           AS "user_session_finished_at"
197                     , s.user_agent            AS "user_session_user_agent"
198                     , s.last_active_at        AS "user_session_last_active_at"
199                     , s.last_active_ip        AS "user_session_last_active_ip: IpAddr"
200                     , u.user_id
201                     , u.username              AS "user_username"
202                     , u.created_at            AS "user_created_at"
203                     , u.locked_at             AS "user_locked_at"
204                     , u.deactivated_at        AS "user_deactivated_at"
205                     , u.can_request_admin     AS "user_can_request_admin"
206                     , u.is_guest              AS "user_is_guest"
207                FROM user_sessions s
208                INNER JOIN users u
209                    USING (user_id)
210                WHERE s.user_session_id = $1
211            "#,
212            Uuid::from(id),
213        )
214        .traced()
215        .fetch_optional(&mut *self.conn)
216        .await?;
217
218        let Some(res) = res else { return Ok(None) };
219
220        Ok(Some(res.try_into()?))
221    }
222
223    #[tracing::instrument(
224        name = "db.browser_session.add",
225        skip_all,
226        fields(
227            db.query.text,
228            %user.id,
229            user_session.id,
230        ),
231        err,
232    )]
233    async fn add(
234        &mut self,
235        rng: &mut (dyn RngCore + Send),
236        clock: &dyn Clock,
237        user: &User,
238        user_agent: Option<String>,
239    ) -> Result<BrowserSession, Self::Error> {
240        let created_at = clock.now();
241        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
242        tracing::Span::current().record("user_session.id", tracing::field::display(id));
243
244        sqlx::query!(
245            r#"
246                INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
247                VALUES ($1, $2, $3, $4)
248            "#,
249            Uuid::from(id),
250            Uuid::from(user.id),
251            created_at,
252            user_agent.as_deref(),
253        )
254        .traced()
255        .execute(&mut *self.conn)
256        .await?;
257
258        let session = BrowserSession {
259            id,
260            // XXX
261            user: user.clone(),
262            created_at,
263            finished_at: None,
264            user_agent,
265            last_active_at: None,
266            last_active_ip: None,
267        };
268
269        Ok(session)
270    }
271
272    #[tracing::instrument(
273        name = "db.browser_session.finish",
274        skip_all,
275        fields(
276            db.query.text,
277            %user_session.id,
278        ),
279        err,
280    )]
281    async fn finish(
282        &mut self,
283        clock: &dyn Clock,
284        mut user_session: BrowserSession,
285    ) -> Result<BrowserSession, Self::Error> {
286        let finished_at = clock.now();
287        let res = sqlx::query!(
288            r#"
289                UPDATE user_sessions
290                SET finished_at = $1
291                WHERE user_session_id = $2
292            "#,
293            finished_at,
294            Uuid::from(user_session.id),
295        )
296        .traced()
297        .execute(&mut *self.conn)
298        .await?;
299
300        user_session.finished_at = Some(finished_at);
301
302        DatabaseError::ensure_affected_rows(&res, 1)?;
303
304        Ok(user_session)
305    }
306
307    #[tracing::instrument(
308        name = "db.browser_session.finish_bulk",
309        skip_all,
310        fields(
311            db.query.text,
312        ),
313        err,
314    )]
315    async fn finish_bulk(
316        &mut self,
317        clock: &dyn Clock,
318        filter: BrowserSessionFilter<'_>,
319    ) -> Result<usize, Self::Error> {
320        let finished_at = clock.now();
321        let (sql, arguments) = sea_query::Query::update()
322            .table(UserSessions::Table)
323            .value(UserSessions::FinishedAt, finished_at)
324            .apply_filter(filter)
325            .build_sqlx(PostgresQueryBuilder);
326
327        let res = sqlx::query_with(&sql, arguments)
328            .traced()
329            .execute(&mut *self.conn)
330            .await?;
331
332        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
333    }
334
335    #[tracing::instrument(
336        name = "db.browser_session.list",
337        skip_all,
338        fields(
339            db.query.text,
340        ),
341        err,
342    )]
343    async fn list(
344        &mut self,
345        filter: BrowserSessionFilter<'_>,
346        pagination: Pagination,
347    ) -> Result<Page<BrowserSession>, Self::Error> {
348        let (sql, arguments) = sea_query::Query::select()
349            .expr_as(
350                Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
351                SessionLookupIden::UserSessionId,
352            )
353            .expr_as(
354                Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
355                SessionLookupIden::UserSessionCreatedAt,
356            )
357            .expr_as(
358                Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
359                SessionLookupIden::UserSessionFinishedAt,
360            )
361            .expr_as(
362                Expr::col((UserSessions::Table, UserSessions::UserAgent)),
363                SessionLookupIden::UserSessionUserAgent,
364            )
365            .expr_as(
366                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
367                SessionLookupIden::UserSessionLastActiveAt,
368            )
369            .expr_as(
370                Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
371                SessionLookupIden::UserSessionLastActiveIp,
372            )
373            .expr_as(
374                Expr::col((Users::Table, Users::UserId)),
375                SessionLookupIden::UserId,
376            )
377            .expr_as(
378                Expr::col((Users::Table, Users::Username)),
379                SessionLookupIden::UserUsername,
380            )
381            .expr_as(
382                Expr::col((Users::Table, Users::CreatedAt)),
383                SessionLookupIden::UserCreatedAt,
384            )
385            .expr_as(
386                Expr::col((Users::Table, Users::LockedAt)),
387                SessionLookupIden::UserLockedAt,
388            )
389            .expr_as(
390                Expr::col((Users::Table, Users::DeactivatedAt)),
391                SessionLookupIden::UserDeactivatedAt,
392            )
393            .expr_as(
394                Expr::col((Users::Table, Users::CanRequestAdmin)),
395                SessionLookupIden::UserCanRequestAdmin,
396            )
397            .expr_as(
398                Expr::col((Users::Table, Users::IsGuest)),
399                SessionLookupIden::UserIsGuest,
400            )
401            .from(UserSessions::Table)
402            .inner_join(
403                Users::Table,
404                Expr::col((UserSessions::Table, UserSessions::UserId))
405                    .equals((Users::Table, Users::UserId)),
406            )
407            .apply_filter(filter)
408            .generate_pagination(
409                (UserSessions::Table, UserSessions::UserSessionId),
410                pagination,
411            )
412            .build_sqlx(PostgresQueryBuilder);
413
414        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
415            .traced()
416            .fetch_all(&mut *self.conn)
417            .await?;
418
419        let page = pagination
420            .process(edges)
421            .try_map(BrowserSession::try_from)?;
422
423        Ok(page)
424    }
425
426    #[tracing::instrument(
427        name = "db.browser_session.count",
428        skip_all,
429        fields(
430            db.query.text,
431        ),
432        err,
433    )]
434    async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
435        let (sql, arguments) = sea_query::Query::select()
436            .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
437            .from(UserSessions::Table)
438            .apply_filter(filter)
439            .build_sqlx(PostgresQueryBuilder);
440
441        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
442            .traced()
443            .fetch_one(&mut *self.conn)
444            .await?;
445
446        count
447            .try_into()
448            .map_err(DatabaseError::to_invalid_operation)
449    }
450
451    #[tracing::instrument(
452        name = "db.browser_session.authenticate_with_password",
453        skip_all,
454        fields(
455            db.query.text,
456            %user_session.id,
457            %user_password.id,
458            user_session_authentication.id,
459        ),
460        err,
461    )]
462    async fn authenticate_with_password(
463        &mut self,
464        rng: &mut (dyn RngCore + Send),
465        clock: &dyn Clock,
466        user_session: &BrowserSession,
467        user_password: &Password,
468    ) -> Result<Authentication, Self::Error> {
469        let created_at = clock.now();
470        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
471        tracing::Span::current().record(
472            "user_session_authentication.id",
473            tracing::field::display(id),
474        );
475
476        sqlx::query!(
477            r#"
478                INSERT INTO user_session_authentications
479                    (user_session_authentication_id, user_session_id, created_at, user_password_id)
480                VALUES ($1, $2, $3, $4)
481            "#,
482            Uuid::from(id),
483            Uuid::from(user_session.id),
484            created_at,
485            Uuid::from(user_password.id),
486        )
487        .traced()
488        .execute(&mut *self.conn)
489        .await?;
490
491        Ok(Authentication {
492            id,
493            created_at,
494            authentication_method: AuthenticationMethod::Password {
495                user_password_id: user_password.id,
496            },
497        })
498    }
499
500    #[tracing::instrument(
501        name = "db.browser_session.authenticate_with_upstream",
502        skip_all,
503        fields(
504            db.query.text,
505            %user_session.id,
506            %upstream_oauth_session.id,
507            user_session_authentication.id,
508        ),
509        err,
510    )]
511    async fn authenticate_with_upstream(
512        &mut self,
513        rng: &mut (dyn RngCore + Send),
514        clock: &dyn Clock,
515        user_session: &BrowserSession,
516        upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
517    ) -> Result<Authentication, Self::Error> {
518        let created_at = clock.now();
519        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
520        tracing::Span::current().record(
521            "user_session_authentication.id",
522            tracing::field::display(id),
523        );
524
525        sqlx::query!(
526            r#"
527                INSERT INTO user_session_authentications
528                    (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
529                VALUES ($1, $2, $3, $4)
530            "#,
531            Uuid::from(id),
532            Uuid::from(user_session.id),
533            created_at,
534            Uuid::from(upstream_oauth_session.id),
535        )
536        .traced()
537        .execute(&mut *self.conn)
538        .await?;
539
540        Ok(Authentication {
541            id,
542            created_at,
543            authentication_method: AuthenticationMethod::UpstreamOAuth2 {
544                upstream_oauth2_session_id: upstream_oauth_session.id,
545            },
546        })
547    }
548
549    #[tracing::instrument(
550        name = "db.browser_session.get_last_authentication",
551        skip_all,
552        fields(
553            db.query.text,
554            %user_session.id,
555        ),
556        err,
557    )]
558    async fn get_last_authentication(
559        &mut self,
560        user_session: &BrowserSession,
561    ) -> Result<Option<Authentication>, Self::Error> {
562        let authentication = sqlx::query_as!(
563            AuthenticationLookup,
564            r#"
565                SELECT user_session_authentication_id
566                     , created_at
567                     , user_password_id
568                     , upstream_oauth_authorization_session_id
569                FROM user_session_authentications
570                WHERE user_session_id = $1
571                ORDER BY created_at DESC
572                LIMIT 1
573            "#,
574            Uuid::from(user_session.id),
575        )
576        .traced()
577        .fetch_optional(&mut *self.conn)
578        .await?;
579
580        let Some(authentication) = authentication else {
581            return Ok(None);
582        };
583
584        let authentication = Authentication::try_from(authentication)?;
585        Ok(Some(authentication))
586    }
587
588    #[tracing::instrument(
589        name = "db.browser_session.record_batch_activity",
590        skip_all,
591        fields(
592            db.query.text,
593        ),
594        err,
595    )]
596    async fn record_batch_activity(
597        &mut self,
598        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
599    ) -> Result<(), Self::Error> {
600        // Sort the activity by ID, so that when batching the updates, Postgres
601        // locks the rows in a stable order, preventing deadlocks
602        activities.sort_unstable();
603        let mut ids = Vec::with_capacity(activities.len());
604        let mut last_activities = Vec::with_capacity(activities.len());
605        let mut ips = Vec::with_capacity(activities.len());
606
607        for (id, last_activity, ip) in activities {
608            ids.push(Uuid::from(id));
609            last_activities.push(last_activity);
610            ips.push(ip);
611        }
612
613        let res = sqlx::query!(
614            r#"
615                UPDATE user_sessions
616                SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
617                  , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
618                FROM (
619                    SELECT *
620                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
621                        AS t(user_session_id, last_active_at, last_active_ip)
622                ) AS t
623                WHERE user_sessions.user_session_id = t.user_session_id
624            "#,
625            &ids,
626            &last_activities,
627            &ips as &[Option<IpAddr>],
628        )
629        .traced()
630        .execute(&mut *self.conn)
631        .await?;
632
633        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
634
635        Ok(())
636    }
637}