1use async_trait::async_trait;
11use mas_data_model::{Clock, User};
12use mas_storage::user::{UserFilter, UserRepository};
13use rand::RngCore;
14use sea_query::{Expr, PostgresQueryBuilder, Query};
15use sea_query_binder::SqlxBinder;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{
21 DatabaseError,
22 filter::{Filter, StatementExt},
23 iden::Users,
24 pagination::QueryBuilderExt,
25 tracing::ExecuteExt,
26};
27
28mod email;
29mod password;
30mod recovery;
31mod registration;
32mod registration_token;
33mod session;
34mod terms;
35
36#[cfg(test)]
37mod tests;
38
39pub use self::{
40 email::PgUserEmailRepository, password::PgUserPasswordRepository,
41 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
42 registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
43 terms::PgUserTermsRepository,
44};
45
46pub struct PgUserRepository<'c> {
48 conn: &'c mut PgConnection,
49}
50
51impl<'c> PgUserRepository<'c> {
52 pub fn new(conn: &'c mut PgConnection) -> Self {
54 Self { conn }
55 }
56}
57
mod priv_ {
    // `#[enum_def]` below generates a companion `UserLookupIden` enum, which
    // the parent module imports to alias the selected columns. The lint is
    // silenced for the whole module, presumably because that generated enum
    // carries no documentation.
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    // Row shape returned by the user queries of the parent module. The field
    // names match the columns (or column aliases) those queries select.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        // `None` while the user is not locked.
        pub(super) locked_at: Option<DateTime<Utc>>,
        // `None` while the user is not deactivated.
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
        pub(super) is_guest: bool,
    }
}
79
80use priv_::{UserLookup, UserLookupIden};
81
82impl From<UserLookup> for User {
83 fn from(value: UserLookup) -> Self {
84 let id = value.user_id.into();
85 Self {
86 id,
87 username: value.username,
88 sub: id.to_string(),
89 created_at: value.created_at,
90 locked_at: value.locked_at,
91 deactivated_at: value.deactivated_at,
92 can_request_admin: value.can_request_admin,
93 is_guest: value.is_guest,
94 }
95 }
96}
97
98impl Filter for UserFilter<'_> {
99 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
100 sea_query::Condition::all()
101 .add_option(self.state().map(|state| {
102 match state {
103 mas_storage::user::UserState::Deactivated => {
104 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
105 }
106 mas_storage::user::UserState::Locked => {
107 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
108 }
109 mas_storage::user::UserState::Active => {
110 Expr::col((Users::Table, Users::LockedAt))
111 .is_null()
112 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
113 }
114 }
115 }))
116 .add_option(self.can_request_admin().map(|can_request_admin| {
117 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
118 }))
119 .add_option(
120 self.is_guest()
121 .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
122 )
123 }
124}
125
126#[async_trait]
127impl UserRepository for PgUserRepository<'_> {
128 type Error = DatabaseError;
129
130 #[tracing::instrument(
131 name = "db.user.lookup",
132 skip_all,
133 fields(
134 db.query.text,
135 user.id = %id,
136 ),
137 err,
138 )]
139 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
140 let res = sqlx::query_as!(
141 UserLookup,
142 r#"
143 SELECT user_id
144 , username
145 , created_at
146 , locked_at
147 , deactivated_at
148 , can_request_admin
149 , is_guest
150 FROM users
151 WHERE user_id = $1
152 "#,
153 Uuid::from(id),
154 )
155 .traced()
156 .fetch_optional(&mut *self.conn)
157 .await?;
158
159 let Some(res) = res else { return Ok(None) };
160
161 Ok(Some(res.into()))
162 }
163
164 #[tracing::instrument(
165 name = "db.user.find_by_username",
166 skip_all,
167 fields(
168 db.query.text,
169 user.username = username,
170 ),
171 err,
172 )]
173 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
174 let res = sqlx::query_as!(
178 UserLookup,
179 r#"
180 SELECT user_id
181 , username
182 , created_at
183 , locked_at
184 , deactivated_at
185 , can_request_admin
186 , is_guest
187 FROM users
188 WHERE LOWER(username) = LOWER($1)
189 "#,
190 username,
191 )
192 .traced()
193 .fetch_all(&mut *self.conn)
194 .await?;
195
196 match &res[..] {
197 [user] => Ok(Some(user.clone().into())),
199 [] => Ok(None),
201 list => {
202 if let Some(user) = list.iter().find(|user| user.username == username) {
205 Ok(Some(user.clone().into()))
206 } else {
207 Ok(None)
209 }
210 }
211 }
212 }
213
214 #[tracing::instrument(
215 name = "db.user.add",
216 skip_all,
217 fields(
218 db.query.text,
219 user.username = username,
220 user.id,
221 ),
222 err,
223 )]
224 async fn add(
225 &mut self,
226 rng: &mut (dyn RngCore + Send),
227 clock: &dyn Clock,
228 username: String,
229 ) -> Result<User, Self::Error> {
230 let created_at = clock.now();
231 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
232 tracing::Span::current().record("user.id", tracing::field::display(id));
233
234 let res = sqlx::query!(
235 r#"
236 INSERT INTO users (user_id, username, created_at)
237 VALUES ($1, $2, $3)
238 ON CONFLICT (username) DO NOTHING
239 "#,
240 Uuid::from(id),
241 username,
242 created_at,
243 )
244 .traced()
245 .execute(&mut *self.conn)
246 .await?;
247
248 DatabaseError::ensure_affected_rows(&res, 1)?;
251
252 Ok(User {
253 id,
254 username,
255 sub: id.to_string(),
256 created_at,
257 locked_at: None,
258 deactivated_at: None,
259 can_request_admin: false,
260 is_guest: false,
261 })
262 }
263
264 #[tracing::instrument(
265 name = "db.user.exists",
266 skip_all,
267 fields(
268 db.query.text,
269 user.username = username,
270 ),
271 err,
272 )]
273 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
274 let exists = sqlx::query_scalar!(
275 r#"
276 SELECT EXISTS(
277 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
278 ) AS "exists!"
279 "#,
280 username
281 )
282 .traced()
283 .fetch_one(&mut *self.conn)
284 .await?;
285
286 Ok(exists)
287 }
288
289 #[tracing::instrument(
290 name = "db.user.lock",
291 skip_all,
292 fields(
293 db.query.text,
294 %user.id,
295 ),
296 err,
297 )]
298 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
299 if user.locked_at.is_some() {
300 return Ok(user);
301 }
302
303 let locked_at = clock.now();
304 let res = sqlx::query!(
305 r#"
306 UPDATE users
307 SET locked_at = $1
308 WHERE user_id = $2
309 "#,
310 locked_at,
311 Uuid::from(user.id),
312 )
313 .traced()
314 .execute(&mut *self.conn)
315 .await?;
316
317 DatabaseError::ensure_affected_rows(&res, 1)?;
318
319 user.locked_at = Some(locked_at);
320
321 Ok(user)
322 }
323
324 #[tracing::instrument(
325 name = "db.user.unlock",
326 skip_all,
327 fields(
328 db.query.text,
329 %user.id,
330 ),
331 err,
332 )]
333 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
334 if user.locked_at.is_none() {
335 return Ok(user);
336 }
337
338 let res = sqlx::query!(
339 r#"
340 UPDATE users
341 SET locked_at = NULL
342 WHERE user_id = $1
343 "#,
344 Uuid::from(user.id),
345 )
346 .traced()
347 .execute(&mut *self.conn)
348 .await?;
349
350 DatabaseError::ensure_affected_rows(&res, 1)?;
351
352 user.locked_at = None;
353
354 Ok(user)
355 }
356
357 #[tracing::instrument(
358 name = "db.user.deactivate",
359 skip_all,
360 fields(
361 db.query.text,
362 %user.id,
363 ),
364 err,
365 )]
366 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
367 if user.deactivated_at.is_some() {
368 return Ok(user);
369 }
370
371 let deactivated_at = clock.now();
372 let res = sqlx::query!(
373 r#"
374 UPDATE users
375 SET deactivated_at = $2
376 WHERE user_id = $1
377 AND deactivated_at IS NULL
378 "#,
379 Uuid::from(user.id),
380 deactivated_at,
381 )
382 .traced()
383 .execute(&mut *self.conn)
384 .await?;
385
386 DatabaseError::ensure_affected_rows(&res, 1)?;
387
388 user.deactivated_at = Some(deactivated_at);
389
390 Ok(user)
391 }
392
393 #[tracing::instrument(
394 name = "db.user.reactivate",
395 skip_all,
396 fields(
397 db.query.text,
398 %user.id,
399 ),
400 err,
401 )]
402 async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
403 if user.deactivated_at.is_none() {
404 return Ok(user);
405 }
406
407 let res = sqlx::query!(
408 r#"
409 UPDATE users
410 SET deactivated_at = NULL
411 WHERE user_id = $1
412 "#,
413 Uuid::from(user.id),
414 )
415 .traced()
416 .execute(&mut *self.conn)
417 .await?;
418
419 DatabaseError::ensure_affected_rows(&res, 1)?;
420
421 user.deactivated_at = None;
422
423 Ok(user)
424 }
425
426 #[tracing::instrument(
427 name = "db.user.set_can_request_admin",
428 skip_all,
429 fields(
430 db.query.text,
431 %user.id,
432 user.can_request_admin = can_request_admin,
433 ),
434 err,
435 )]
436 async fn set_can_request_admin(
437 &mut self,
438 mut user: User,
439 can_request_admin: bool,
440 ) -> Result<User, Self::Error> {
441 let res = sqlx::query!(
442 r#"
443 UPDATE users
444 SET can_request_admin = $2
445 WHERE user_id = $1
446 "#,
447 Uuid::from(user.id),
448 can_request_admin,
449 )
450 .traced()
451 .execute(&mut *self.conn)
452 .await?;
453
454 DatabaseError::ensure_affected_rows(&res, 1)?;
455
456 user.can_request_admin = can_request_admin;
457
458 Ok(user)
459 }
460
461 #[tracing::instrument(
462 name = "db.user.list",
463 skip_all,
464 fields(
465 db.query.text,
466 ),
467 err,
468 )]
469 async fn list(
470 &mut self,
471 filter: UserFilter<'_>,
472 pagination: mas_storage::Pagination,
473 ) -> Result<mas_storage::Page<User>, Self::Error> {
474 let (sql, arguments) = Query::select()
475 .expr_as(
476 Expr::col((Users::Table, Users::UserId)),
477 UserLookupIden::UserId,
478 )
479 .expr_as(
480 Expr::col((Users::Table, Users::Username)),
481 UserLookupIden::Username,
482 )
483 .expr_as(
484 Expr::col((Users::Table, Users::CreatedAt)),
485 UserLookupIden::CreatedAt,
486 )
487 .expr_as(
488 Expr::col((Users::Table, Users::LockedAt)),
489 UserLookupIden::LockedAt,
490 )
491 .expr_as(
492 Expr::col((Users::Table, Users::DeactivatedAt)),
493 UserLookupIden::DeactivatedAt,
494 )
495 .expr_as(
496 Expr::col((Users::Table, Users::CanRequestAdmin)),
497 UserLookupIden::CanRequestAdmin,
498 )
499 .expr_as(
500 Expr::col((Users::Table, Users::IsGuest)),
501 UserLookupIden::IsGuest,
502 )
503 .from(Users::Table)
504 .apply_filter(filter)
505 .generate_pagination((Users::Table, Users::UserId), pagination)
506 .build_sqlx(PostgresQueryBuilder);
507
508 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
509 .traced()
510 .fetch_all(&mut *self.conn)
511 .await?;
512
513 let page = pagination.process(edges).map(User::from);
514
515 Ok(page)
516 }
517
518 #[tracing::instrument(
519 name = "db.user.count",
520 skip_all,
521 fields(
522 db.query.text,
523 ),
524 err,
525 )]
526 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
527 let (sql, arguments) = Query::select()
528 .expr(Expr::col((Users::Table, Users::UserId)).count())
529 .from(Users::Table)
530 .apply_filter(filter)
531 .build_sqlx(PostgresQueryBuilder);
532
533 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
534 .traced()
535 .fetch_one(&mut *self.conn)
536 .await?;
537
538 count
539 .try_into()
540 .map_err(DatabaseError::to_invalid_operation)
541 }
542
543 #[tracing::instrument(
544 name = "db.user.acquire_lock_for_sync",
545 skip_all,
546 fields(
547 db.query.text,
548 user.id = %user.id,
549 ),
550 err,
551 )]
552 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
553 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
561
562 sqlx::query!(
565 r#"
566 SELECT pg_advisory_xact_lock($1)
567 "#,
568 lock_id,
569 )
570 .traced()
571 .execute(&mut *self.conn)
572 .await?;
573
574 Ok(())
575 }
576}