Files
server/src/Infrastructure.Dapper/Repositories/UserRepository.cs

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

305 lines
9.9 KiB
C#
Raw Normal View History

using System.Data;
using Bit.Core;
using Bit.Core.Auth.UserFeatures.UserKey;
using Bit.Core.Entities;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Settings;
using Dapper;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.Data.SqlClient;
namespace Bit.Infrastructure.Dapper.Repositories;
2022-08-29 16:06:55 -04:00
/// <summary>
/// Dapper / SQL Server implementation of <see cref="IUserRepository"/>.
/// </summary>
/// <remarks>
/// <see cref="User.MasterPassword"/> and <see cref="User.Key"/> are written as
/// <c>Constants.DatabaseFieldProtectedPrefix</c> followed by an <see cref="IDataProtector"/> payload
/// (see <see cref="CreateAsync"/>, <see cref="ReplaceAsync"/> and
/// <see cref="UpdateUserKeyAndEncryptedDataAsync"/>). Every <see cref="User"/> read by this class is
/// unprotected before being returned; values that do not carry the prefix are returned unchanged.
/// </remarks>
public class UserRepository : Repository<User, Guid>, IUserRepository
{
    private readonly IDataProtector _dataProtector;

    /// <summary>
    /// Creates a repository using the SQL Server connection strings from <paramref name="globalSettings"/>
    /// and a data protector used for the protected user fields.
    /// </summary>
    public UserRepository(
        GlobalSettings globalSettings,
        IDataProtectionProvider dataProtectionProvider)
        : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString)
    {
        _dataProtector = dataProtectionProvider.CreateProtector(Constants.DatabaseFieldProtectorPurpose);
    }

    /// <summary>
    /// Creates a repository from raw read/write and read-only connection strings.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this overload leaves <c>_dataProtector</c> null, so reading a protected field or
    /// calling any write path that protects fields will fail. Confirm callers of this overload never
    /// reach those paths.
    /// </remarks>
    public UserRepository(string connectionString, string readOnlyConnectionString)
        : base(connectionString, readOnlyConnectionString)
    { }

    /// <inheritdoc />
    /// <remarks>Protected fields of the returned user are unprotected in place.</remarks>
    public override async Task<User> GetByIdAsync(Guid id)
    {
        var user = await base.GetByIdAsync(id);
        UnprotectData(user);
        return user;
    }

    /// <summary>
    /// Looks up a user by email via <c>{Table}_ReadByEmail</c>.
    /// </summary>
    /// <returns>The user, or <c>null</c> if none matches; throws if more than one row is returned.</returns>
    public async Task<User> GetByEmailAsync(string email)
    {
        using (var connection = new SqlConnection(ConnectionString))
        {
            var results = await connection.QueryAsync<User>(
                $"[{Schema}].[{Table}_ReadByEmail]",
                new { Email = email },
                commandType: CommandType.StoredProcedure);

            UnprotectData(results);
            return results.SingleOrDefault();
        }
    }

    /// <summary>
    /// Looks up the user linked to an SSO identity (<paramref name="externalId"/>) for an organization.
    /// </summary>
    /// <returns>The user, or <c>null</c> if none matches; throws if more than one row is returned.</returns>
    public async Task<User> GetBySsoUserAsync(string externalId, Guid? organizationId)
    {
        using (var connection = new SqlConnection(ConnectionString))
        {
            var results = await connection.QueryAsync<User>(
                $"[{Schema}].[{Table}_ReadBySsoUserOrganizationIdExternalId]",
                new { OrganizationId = organizationId, ExternalId = externalId },
                commandType: CommandType.StoredProcedure);

            UnprotectData(results);
            return results.SingleOrDefault();
        }
    }

    /// <summary>
    /// Reads the KDF information for the user with the given email.
    /// </summary>
    /// <returns>The KDF information, or <c>null</c> if no row is returned.</returns>
    public async Task<UserKdfInformation> GetKdfInformationByEmailAsync(string email)
    {
        using (var connection = new SqlConnection(ConnectionString))
        {
            var results = await connection.QueryAsync<UserKdfInformation>(
                $"[{Schema}].[{Table}_ReadKdfByEmail]",
                new { Email = email },
                commandType: CommandType.StoredProcedure);

            return results.SingleOrDefault();
        }
    }

    /// <summary>
    /// Searches users by email with paging, against the read-only connection.
    /// </summary>
    /// <remarks>Uses a 120-second command timeout.</remarks>
    public async Task<ICollection<User>> SearchAsync(string email, int skip, int take)
    {
        using (var connection = new SqlConnection(ReadOnlyConnectionString))
        {
            var results = await connection.QueryAsync<User>(
                $"[{Schema}].[{Table}_Search]",
                new { Email = email, Skip = skip, Take = take },
                commandType: CommandType.StoredProcedure,
                commandTimeout: 120);

            UnprotectData(results);
            return results.ToList();
        }
    }

    /// <summary>
    /// Returns all users whose premium flag matches <paramref name="premium"/>.
    /// </summary>
    public async Task<ICollection<User>> GetManyByPremiumAsync(bool premium)
    {
        using (var connection = new SqlConnection(ConnectionString))
        {
            var results = await connection.QueryAsync<User>(
                "[dbo].[User_ReadByPremium]",
                new { Premium = premium },
                commandType: CommandType.StoredProcedure);

            UnprotectData(results);
            return results.ToList();
        }
    }

    /// <summary>
    /// Returns the public key of the user with the given id, or <c>null</c> if no row is returned.
    /// </summary>
    public async Task<string> GetPublicKeyAsync(Guid id)
    {
        using (var connection = new SqlConnection(ConnectionString))
        {
            var results = await connection.QueryAsync<string>(
                $"[{Schema}].[{Table}_ReadPublicKeyById]",
                new { Id = id },
                commandType: CommandType.StoredProcedure);

            return results.SingleOrDefault();
        }
    }

    /// <summary>
    /// Returns the account revision date of the user with the given id.
    /// </summary>
    /// <remarks>
    /// Returns <c>default(DateTime)</c> when no row is returned.
    /// NOTE(review): reads from <c>ReadOnlyConnectionString</c>; if that targets a replica the value may
    /// lag recent writes. Confirm this is acceptable for callers.
    /// </remarks>
    public async Task<DateTime> GetAccountRevisionDateAsync(Guid id)
    {
        using (var connection = new SqlConnection(ReadOnlyConnectionString))
        {
            var results = await connection.QueryAsync<DateTime>(
                $"[{Schema}].[{Table}_ReadAccountRevisionDateById]",
                new { Id = id },
                commandType: CommandType.StoredProcedure);

            return results.SingleOrDefault();
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Protected fields are encrypted only for the duration of the insert; <paramref name="user"/> is
    /// handed back with its original plaintext values.
    /// </remarks>
    public override async Task<User> CreateAsync(User user)
    {
        await ProtectDataAndSaveAsync(user, async () => await base.CreateAsync(user));
        return user;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Protected fields are encrypted only for the duration of the update; <paramref name="user"/> keeps
    /// its original plaintext values afterwards.
    /// </remarks>
    public override async Task ReplaceAsync(User user)
    {
        await ProtectDataAndSaveAsync(user, async () => await base.ReplaceAsync(user));
    }

    /// <inheritdoc />
    /// <remarks>Calls <c>{Table}_DeleteById</c> with a 180-second command timeout.</remarks>
    public override async Task DeleteAsync(User user)
    {
        using (var connection = new SqlConnection(ConnectionString))
        {
            await connection.ExecuteAsync(
                $"[{Schema}].[{Table}_DeleteById]",
                new { Id = user.Id },
                commandType: CommandType.StoredProcedure,
                commandTimeout: 180);
        }
    }

    /// <summary>
    /// Runs <c>{Table}_UpdateStorage</c> for the given user (180-second command timeout).
    /// </summary>
    public async Task UpdateStorageAsync(Guid id)
    {
        using (var connection = new SqlConnection(ConnectionString))
        {
            await connection.ExecuteAsync(
                $"[{Schema}].[{Table}_UpdateStorage]",
                new { Id = id },
                commandType: CommandType.StoredProcedure,
                commandTimeout: 180);
        }
    }

    /// <summary>
    /// Sets the renewal reminder date of the given user.
    /// </summary>
    public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate)
    {
        using (var connection = new SqlConnection(ConnectionString))
        {
            await connection.ExecuteAsync(
                $"[{Schema}].[User_UpdateRenewalReminderDate]",
                new { Id = id, RenewalReminderDate = renewalReminderDate },
                commandType: CommandType.StoredProcedure);
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Runs <c>[dbo].[User_UpdateKeys]</c> and then every action in <paramref name="updateDataActions"/>
    /// inside one transaction. Any exception rolls the whole transaction back and is rethrown.
    /// <see cref="User.Key"/> is written with the same data protection that <see cref="ReplaceAsync"/>
    /// applies, so the key is never stored in plaintext. <paramref name="user"/> itself is not modified.
    /// </remarks>
    public async Task UpdateUserKeyAndEncryptedDataAsync(
        User user,
        IEnumerable<UpdateEncryptedDataForKeyRotation> updateDataActions)
    {
        await using var connection = new SqlConnection(ConnectionString);
        await connection.OpenAsync();

        // Synchronous BeginTransaction is kept on purpose: it returns a SqlTransaction, which the
        // SqlCommand constructor below requires (BeginTransactionAsync only yields a DbTransaction).
        await using var transaction = connection.BeginTransaction();
        try
        {
            // Update user
            await using (var cmd = new SqlCommand("[dbo].[User_UpdateKeys]", connection, transaction))
            {
                cmd.CommandType = CommandType.StoredProcedure;
                cmd.Parameters.Add("@Id", SqlDbType.UniqueIdentifier).Value = user.Id;
                cmd.Parameters.Add("@SecurityStamp", SqlDbType.NVarChar).Value =
                    ToDbValue(user.SecurityStamp);
                // Protect the key exactly like CreateAsync/ReplaceAsync do, without mutating `user`.
                cmd.Parameters.Add("@Key", SqlDbType.VarChar).Value =
                    ToDbValue(ProtectField(user.Key));
                cmd.Parameters.Add("@PrivateKey", SqlDbType.VarChar).Value =
                    string.IsNullOrWhiteSpace(user.PrivateKey) ? DBNull.Value : user.PrivateKey;
                cmd.Parameters.Add("@RevisionDate", SqlDbType.DateTime2).Value = user.RevisionDate;
                cmd.Parameters.Add("@AccountRevisionDate", SqlDbType.DateTime2).Value =
                    user.AccountRevisionDate;
                cmd.Parameters.Add("@LastKeyRotationDate", SqlDbType.DateTime2).Value =
                    ToDbValue(user.LastKeyRotationDate);
                await cmd.ExecuteNonQueryAsync();
            }

            // Update re-encrypted data
            foreach (var action in updateDataActions)
            {
                await action(transaction);
            }

            await transaction.CommitAsync();
        }
        catch
        {
            await transaction.RollbackAsync();
            throw;
        }
    }

    /// <summary>
    /// Returns all users whose ids are in <paramref name="ids"/>, read from the read-only connection.
    /// </summary>
    public async Task<IEnumerable<User>> GetManyAsync(IEnumerable<Guid> ids)
    {
        using (var connection = new SqlConnection(ReadOnlyConnectionString))
        {
            var results = await connection.QueryAsync<User>(
                $"[{Schema}].[{Table}_ReadByIds]",
                new { Ids = ids.ToGuidIdArrayTVP() },
                commandType: CommandType.StoredProcedure);

            UnprotectData(results);
            return results.ToList();
        }
    }

    /// <summary>
    /// Temporarily replaces <see cref="User.MasterPassword"/> and <see cref="User.Key"/> with their
    /// protected forms, runs <paramref name="saveTask"/>, then restores the original values.
    /// </summary>
    /// <param name="user">The entity being saved; when <c>null</c> the save runs without protection.</param>
    /// <param name="saveTask">The persistence operation that reads the (protected) entity.</param>
    private async Task ProtectDataAndSaveAsync(User user, Func<Task> saveTask)
    {
        if (user is null)
        {
            await saveTask();
            return;
        }

        var originalMasterPassword = user.MasterPassword;
        var originalKey = user.Key;

        user.MasterPassword = ProtectField(originalMasterPassword);
        user.Key = ProtectField(originalKey);
        try
        {
            await saveTask();
        }
        finally
        {
            // Restore even when the save throws, so callers never observe ciphertext on their entity.
            user.MasterPassword = originalMasterPassword;
            user.Key = originalKey;
        }
    }

    /// <summary>
    /// Unprotects <see cref="User.MasterPassword"/> and <see cref="User.Key"/> in place; a <c>null</c>
    /// user is ignored.
    /// </summary>
    private void UnprotectData(User user)
    {
        if (user is null)
        {
            return;
        }

        user.MasterPassword = UnprotectField(user.MasterPassword);
        user.Key = UnprotectField(user.Key);
    }

    /// <summary>
    /// Unprotects every user in <paramref name="users"/>; a <c>null</c> sequence is ignored.
    /// </summary>
    private void UnprotectData(IEnumerable<User> users)
    {
        if (users is null)
        {
            return;
        }

        foreach (var user in users)
        {
            UnprotectData(user);
        }
    }

    /// <summary>
    /// Returns the prefixed, protected form of <paramref name="value"/>. <c>null</c> and values that
    /// already carry the protected prefix are returned unchanged, so protection is never applied twice.
    /// </summary>
    private string ProtectField(string value)
    {
        if (value is null ||
            value.StartsWith(Constants.DatabaseFieldProtectedPrefix, StringComparison.Ordinal))
        {
            return value;
        }

        return string.Concat(Constants.DatabaseFieldProtectedPrefix, _dataProtector.Protect(value));
    }

    /// <summary>
    /// Strips the protected prefix from <paramref name="value"/> and unprotects the remainder.
    /// <c>null</c> and values without the prefix are returned unchanged.
    /// </summary>
    private string UnprotectField(string value)
    {
        if (value is null ||
            !value.StartsWith(Constants.DatabaseFieldProtectedPrefix, StringComparison.Ordinal))
        {
            return value;
        }

        return _dataProtector.Unprotect(value.Substring(Constants.DatabaseFieldProtectedPrefix.Length));
    }

    /// <summary>
    /// Maps <c>null</c> to <see cref="DBNull.Value"/>. ADO.NET treats a parameter whose <c>Value</c> is
    /// <c>null</c> as "not supplied", which fails for stored-procedure parameters without a default.
    /// </summary>
    private static object ToDbValue(object value) => value ?? DBNull.Value;
}