1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 Authentication, AuthenticationMethod, BrowserSession, Clock, Password,
13 UpstreamOAuthAuthorizationSession, User,
14};
15use mas_storage::{
16 Page, Pagination,
17 user::{BrowserSessionFilter, BrowserSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27 DatabaseError, DatabaseInconsistencyError,
28 filter::StatementExt,
29 iden::{UpstreamOAuthAuthorizationSessions, UserSessionAuthentications, UserSessions, Users},
30 pagination::QueryBuilderExt,
31 tracing::ExecuteExt,
32};
33
/// An implementation of [`BrowserSessionRepository`] backed by a PostgreSQL
/// connection
pub struct PgBrowserSessionRepository<'c> {
    /// Exclusive borrow of the connection every query of this repository is
    /// executed on
    conn: &'c mut PgConnection,
}
39
40impl<'c> PgBrowserSessionRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
/// Flat row produced by the queries which load a browser session together with
/// the user owning it (`user_sessions` joined with `users`).
///
/// `#[sea_query::enum_def]` generates the `SessionLookupIden` enum, whose
/// variants are used as column aliases when the listing query is built.
// The shared `user_` prefix is intentional: field names have to match the
// column aliases selected by the queries, so the lint is silenced.
#[allow(clippy::struct_field_names)]
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
    // Columns coming from `user_sessions`
    user_session_id: Uuid,
    user_session_created_at: DateTime<Utc>,
    user_session_finished_at: Option<DateTime<Utc>>,
    user_session_user_agent: Option<String>,
    user_session_last_active_at: Option<DateTime<Utc>>,
    user_session_last_active_ip: Option<IpAddr>,
    // Columns coming from `users`
    user_id: Uuid,
    user_username: String,
    user_created_at: DateTime<Utc>,
    user_locked_at: Option<DateTime<Utc>>,
    user_deactivated_at: Option<DateTime<Utc>>,
    user_can_request_admin: bool,
    user_is_guest: bool,
}
66
67impl TryFrom<SessionLookup> for BrowserSession {
68 type Error = DatabaseInconsistencyError;
69
70 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
71 let id = Ulid::from(value.user_id);
72 let user = User {
73 id,
74 username: value.user_username,
75 sub: id.to_string(),
76 created_at: value.user_created_at,
77 locked_at: value.user_locked_at,
78 deactivated_at: value.user_deactivated_at,
79 can_request_admin: value.user_can_request_admin,
80 is_guest: value.user_is_guest,
81 };
82
83 Ok(BrowserSession {
84 id: value.user_session_id.into(),
85 user,
86 created_at: value.user_session_created_at,
87 finished_at: value.user_session_finished_at,
88 user_agent: value.user_session_user_agent,
89 last_active_at: value.user_session_last_active_at,
90 last_active_ip: value.user_session_last_active_ip,
91 })
92 }
93}
94
/// Row of the `user_session_authentications` table, as selected when loading
/// the last authentication of a session.
///
/// At most one of `user_password_id` and
/// `upstream_oauth_authorization_session_id` is expected to be set; the
/// conversion to [`Authentication`] rejects rows where both are.
struct AuthenticationLookup {
    user_session_authentication_id: Uuid,
    created_at: DateTime<Utc>,
    user_password_id: Option<Uuid>,
    upstream_oauth_authorization_session_id: Option<Uuid>,
}
101
102impl TryFrom<AuthenticationLookup> for Authentication {
103 type Error = DatabaseInconsistencyError;
104
105 fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
106 let id = Ulid::from(value.user_session_authentication_id);
107 let authentication_method = match (
108 value.user_password_id.map(Into::into),
109 value
110 .upstream_oauth_authorization_session_id
111 .map(Into::into),
112 ) {
113 (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
114 (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
115 upstream_oauth2_session_id,
116 },
117 (None, None) => AuthenticationMethod::Unknown,
118 _ => {
119 return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
120 }
121 };
122
123 Ok(Authentication {
124 id,
125 created_at: value.created_at,
126 authentication_method,
127 })
128 }
129}
130
131impl crate::filter::Filter for BrowserSessionFilter<'_> {
132 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
133 sea_query::Condition::all()
134 .add_option(self.user().map(|user| {
135 Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
136 }))
137 .add_option(self.state().map(|state| {
138 if state.is_active() {
139 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
140 } else {
141 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
142 }
143 }))
144 .add_option(self.last_active_after().map(|last_active_after| {
145 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
146 }))
147 .add_option(self.last_active_before().map(|last_active_before| {
148 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
149 }))
150 .add_option(self.authenticated_by_upstream_sessions().map(|filter| {
151 let join_expr = Expr::col((
154 UserSessionAuthentications::Table,
155 UserSessionAuthentications::UpstreamOAuthAuthorizationSessionId,
156 ))
157 .eq(Expr::col((
158 UpstreamOAuthAuthorizationSessions::Table,
159 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
160 )));
161
162 Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
163 Query::select()
164 .expr(Expr::col((
165 UserSessionAuthentications::Table,
166 UserSessionAuthentications::UserSessionId,
167 )))
168 .from(UserSessionAuthentications::Table)
169 .inner_join(UpstreamOAuthAuthorizationSessions::Table, join_expr)
170 .apply_filter(filter)
171 .take(),
172 )
173 }))
174 }
175}
176
177#[async_trait]
178impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
179 type Error = DatabaseError;
180
181 #[tracing::instrument(
182 name = "db.browser_session.lookup",
183 skip_all,
184 fields(
185 db.query.text,
186 user_session.id = %id,
187 ),
188 err,
189 )]
190 async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
191 let res = sqlx::query_as!(
192 SessionLookup,
193 r#"
194 SELECT s.user_session_id
195 , s.created_at AS "user_session_created_at"
196 , s.finished_at AS "user_session_finished_at"
197 , s.user_agent AS "user_session_user_agent"
198 , s.last_active_at AS "user_session_last_active_at"
199 , s.last_active_ip AS "user_session_last_active_ip: IpAddr"
200 , u.user_id
201 , u.username AS "user_username"
202 , u.created_at AS "user_created_at"
203 , u.locked_at AS "user_locked_at"
204 , u.deactivated_at AS "user_deactivated_at"
205 , u.can_request_admin AS "user_can_request_admin"
206 , u.is_guest AS "user_is_guest"
207 FROM user_sessions s
208 INNER JOIN users u
209 USING (user_id)
210 WHERE s.user_session_id = $1
211 "#,
212 Uuid::from(id),
213 )
214 .traced()
215 .fetch_optional(&mut *self.conn)
216 .await?;
217
218 let Some(res) = res else { return Ok(None) };
219
220 Ok(Some(res.try_into()?))
221 }
222
223 #[tracing::instrument(
224 name = "db.browser_session.add",
225 skip_all,
226 fields(
227 db.query.text,
228 %user.id,
229 user_session.id,
230 ),
231 err,
232 )]
233 async fn add(
234 &mut self,
235 rng: &mut (dyn RngCore + Send),
236 clock: &dyn Clock,
237 user: &User,
238 user_agent: Option<String>,
239 ) -> Result<BrowserSession, Self::Error> {
240 let created_at = clock.now();
241 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
242 tracing::Span::current().record("user_session.id", tracing::field::display(id));
243
244 sqlx::query!(
245 r#"
246 INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
247 VALUES ($1, $2, $3, $4)
248 "#,
249 Uuid::from(id),
250 Uuid::from(user.id),
251 created_at,
252 user_agent.as_deref(),
253 )
254 .traced()
255 .execute(&mut *self.conn)
256 .await?;
257
258 let session = BrowserSession {
259 id,
260 user: user.clone(),
262 created_at,
263 finished_at: None,
264 user_agent,
265 last_active_at: None,
266 last_active_ip: None,
267 };
268
269 Ok(session)
270 }
271
272 #[tracing::instrument(
273 name = "db.browser_session.finish",
274 skip_all,
275 fields(
276 db.query.text,
277 %user_session.id,
278 ),
279 err,
280 )]
281 async fn finish(
282 &mut self,
283 clock: &dyn Clock,
284 mut user_session: BrowserSession,
285 ) -> Result<BrowserSession, Self::Error> {
286 let finished_at = clock.now();
287 let res = sqlx::query!(
288 r#"
289 UPDATE user_sessions
290 SET finished_at = $1
291 WHERE user_session_id = $2
292 "#,
293 finished_at,
294 Uuid::from(user_session.id),
295 )
296 .traced()
297 .execute(&mut *self.conn)
298 .await?;
299
300 user_session.finished_at = Some(finished_at);
301
302 DatabaseError::ensure_affected_rows(&res, 1)?;
303
304 Ok(user_session)
305 }
306
307 #[tracing::instrument(
308 name = "db.browser_session.finish_bulk",
309 skip_all,
310 fields(
311 db.query.text,
312 ),
313 err,
314 )]
315 async fn finish_bulk(
316 &mut self,
317 clock: &dyn Clock,
318 filter: BrowserSessionFilter<'_>,
319 ) -> Result<usize, Self::Error> {
320 let finished_at = clock.now();
321 let (sql, arguments) = sea_query::Query::update()
322 .table(UserSessions::Table)
323 .value(UserSessions::FinishedAt, finished_at)
324 .apply_filter(filter)
325 .build_sqlx(PostgresQueryBuilder);
326
327 let res = sqlx::query_with(&sql, arguments)
328 .traced()
329 .execute(&mut *self.conn)
330 .await?;
331
332 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
333 }
334
335 #[tracing::instrument(
336 name = "db.browser_session.list",
337 skip_all,
338 fields(
339 db.query.text,
340 ),
341 err,
342 )]
343 async fn list(
344 &mut self,
345 filter: BrowserSessionFilter<'_>,
346 pagination: Pagination,
347 ) -> Result<Page<BrowserSession>, Self::Error> {
348 let (sql, arguments) = sea_query::Query::select()
349 .expr_as(
350 Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
351 SessionLookupIden::UserSessionId,
352 )
353 .expr_as(
354 Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
355 SessionLookupIden::UserSessionCreatedAt,
356 )
357 .expr_as(
358 Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
359 SessionLookupIden::UserSessionFinishedAt,
360 )
361 .expr_as(
362 Expr::col((UserSessions::Table, UserSessions::UserAgent)),
363 SessionLookupIden::UserSessionUserAgent,
364 )
365 .expr_as(
366 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
367 SessionLookupIden::UserSessionLastActiveAt,
368 )
369 .expr_as(
370 Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
371 SessionLookupIden::UserSessionLastActiveIp,
372 )
373 .expr_as(
374 Expr::col((Users::Table, Users::UserId)),
375 SessionLookupIden::UserId,
376 )
377 .expr_as(
378 Expr::col((Users::Table, Users::Username)),
379 SessionLookupIden::UserUsername,
380 )
381 .expr_as(
382 Expr::col((Users::Table, Users::CreatedAt)),
383 SessionLookupIden::UserCreatedAt,
384 )
385 .expr_as(
386 Expr::col((Users::Table, Users::LockedAt)),
387 SessionLookupIden::UserLockedAt,
388 )
389 .expr_as(
390 Expr::col((Users::Table, Users::DeactivatedAt)),
391 SessionLookupIden::UserDeactivatedAt,
392 )
393 .expr_as(
394 Expr::col((Users::Table, Users::CanRequestAdmin)),
395 SessionLookupIden::UserCanRequestAdmin,
396 )
397 .expr_as(
398 Expr::col((Users::Table, Users::IsGuest)),
399 SessionLookupIden::UserIsGuest,
400 )
401 .from(UserSessions::Table)
402 .inner_join(
403 Users::Table,
404 Expr::col((UserSessions::Table, UserSessions::UserId))
405 .equals((Users::Table, Users::UserId)),
406 )
407 .apply_filter(filter)
408 .generate_pagination(
409 (UserSessions::Table, UserSessions::UserSessionId),
410 pagination,
411 )
412 .build_sqlx(PostgresQueryBuilder);
413
414 let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
415 .traced()
416 .fetch_all(&mut *self.conn)
417 .await?;
418
419 let page = pagination
420 .process(edges)
421 .try_map(BrowserSession::try_from)?;
422
423 Ok(page)
424 }
425
426 #[tracing::instrument(
427 name = "db.browser_session.count",
428 skip_all,
429 fields(
430 db.query.text,
431 ),
432 err,
433 )]
434 async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
435 let (sql, arguments) = sea_query::Query::select()
436 .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
437 .from(UserSessions::Table)
438 .apply_filter(filter)
439 .build_sqlx(PostgresQueryBuilder);
440
441 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
442 .traced()
443 .fetch_one(&mut *self.conn)
444 .await?;
445
446 count
447 .try_into()
448 .map_err(DatabaseError::to_invalid_operation)
449 }
450
451 #[tracing::instrument(
452 name = "db.browser_session.authenticate_with_password",
453 skip_all,
454 fields(
455 db.query.text,
456 %user_session.id,
457 %user_password.id,
458 user_session_authentication.id,
459 ),
460 err,
461 )]
462 async fn authenticate_with_password(
463 &mut self,
464 rng: &mut (dyn RngCore + Send),
465 clock: &dyn Clock,
466 user_session: &BrowserSession,
467 user_password: &Password,
468 ) -> Result<Authentication, Self::Error> {
469 let created_at = clock.now();
470 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
471 tracing::Span::current().record(
472 "user_session_authentication.id",
473 tracing::field::display(id),
474 );
475
476 sqlx::query!(
477 r#"
478 INSERT INTO user_session_authentications
479 (user_session_authentication_id, user_session_id, created_at, user_password_id)
480 VALUES ($1, $2, $3, $4)
481 "#,
482 Uuid::from(id),
483 Uuid::from(user_session.id),
484 created_at,
485 Uuid::from(user_password.id),
486 )
487 .traced()
488 .execute(&mut *self.conn)
489 .await?;
490
491 Ok(Authentication {
492 id,
493 created_at,
494 authentication_method: AuthenticationMethod::Password {
495 user_password_id: user_password.id,
496 },
497 })
498 }
499
500 #[tracing::instrument(
501 name = "db.browser_session.authenticate_with_upstream",
502 skip_all,
503 fields(
504 db.query.text,
505 %user_session.id,
506 %upstream_oauth_session.id,
507 user_session_authentication.id,
508 ),
509 err,
510 )]
511 async fn authenticate_with_upstream(
512 &mut self,
513 rng: &mut (dyn RngCore + Send),
514 clock: &dyn Clock,
515 user_session: &BrowserSession,
516 upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
517 ) -> Result<Authentication, Self::Error> {
518 let created_at = clock.now();
519 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
520 tracing::Span::current().record(
521 "user_session_authentication.id",
522 tracing::field::display(id),
523 );
524
525 sqlx::query!(
526 r#"
527 INSERT INTO user_session_authentications
528 (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
529 VALUES ($1, $2, $3, $4)
530 "#,
531 Uuid::from(id),
532 Uuid::from(user_session.id),
533 created_at,
534 Uuid::from(upstream_oauth_session.id),
535 )
536 .traced()
537 .execute(&mut *self.conn)
538 .await?;
539
540 Ok(Authentication {
541 id,
542 created_at,
543 authentication_method: AuthenticationMethod::UpstreamOAuth2 {
544 upstream_oauth2_session_id: upstream_oauth_session.id,
545 },
546 })
547 }
548
549 #[tracing::instrument(
550 name = "db.browser_session.get_last_authentication",
551 skip_all,
552 fields(
553 db.query.text,
554 %user_session.id,
555 ),
556 err,
557 )]
558 async fn get_last_authentication(
559 &mut self,
560 user_session: &BrowserSession,
561 ) -> Result<Option<Authentication>, Self::Error> {
562 let authentication = sqlx::query_as!(
563 AuthenticationLookup,
564 r#"
565 SELECT user_session_authentication_id
566 , created_at
567 , user_password_id
568 , upstream_oauth_authorization_session_id
569 FROM user_session_authentications
570 WHERE user_session_id = $1
571 ORDER BY created_at DESC
572 LIMIT 1
573 "#,
574 Uuid::from(user_session.id),
575 )
576 .traced()
577 .fetch_optional(&mut *self.conn)
578 .await?;
579
580 let Some(authentication) = authentication else {
581 return Ok(None);
582 };
583
584 let authentication = Authentication::try_from(authentication)?;
585 Ok(Some(authentication))
586 }
587
588 #[tracing::instrument(
589 name = "db.browser_session.record_batch_activity",
590 skip_all,
591 fields(
592 db.query.text,
593 ),
594 err,
595 )]
596 async fn record_batch_activity(
597 &mut self,
598 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
599 ) -> Result<(), Self::Error> {
600 activities.sort_unstable();
603 let mut ids = Vec::with_capacity(activities.len());
604 let mut last_activities = Vec::with_capacity(activities.len());
605 let mut ips = Vec::with_capacity(activities.len());
606
607 for (id, last_activity, ip) in activities {
608 ids.push(Uuid::from(id));
609 last_activities.push(last_activity);
610 ips.push(ip);
611 }
612
613 let res = sqlx::query!(
614 r#"
615 UPDATE user_sessions
616 SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
617 , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
618 FROM (
619 SELECT *
620 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
621 AS t(user_session_id, last_active_at, last_active_ip)
622 ) AS t
623 WHERE user_sessions.user_session_id = t.user_session_id
624 "#,
625 &ids,
626 &last_activities,
627 &ips as &[Option<IpAddr>],
628 )
629 .traced()
630 .execute(&mut *self.conn)
631 .await?;
632
633 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
634
635 Ok(())
636 }
637}