Faking the database in unit tests (EF + NUnit)

First of all, note that this is a cheat sheet on the topic, not a detailed explanation or analysis. If you want to understand exactly how everything works, you will have to read the code yourself or contact me for clarification.

To apply the approach described below, you need to install the libraries used in the code: Entity Framework 6, NUnit and, for the AutoData attribute in the sample test, AutoFixture.

Setting up the project under test

To keep the code listings in this article short, we will work with a single table; let's call it Student.

In the application, we have, respectively, the class

/// <summary>
/// Entity for the single table used throughout this example. The numeric
/// key is inherited from <see cref="BaseEntity"/>.
/// </summary>
public class Student : BaseEntity
{
    /// <summary>Gets or sets the student's name.</summary>
    public string Name { get; set; }
}

and

/// <summary>
/// Common base for entities that are identified by a numeric key.
/// </summary>
public class BaseEntity
{
    /// <summary>Gets or sets the numeric identifier of the entity.</summary>
    public int Id { get; set; }
}

We will talk about the nature of BaseEntity a little later; for now it is enough to note that in the vast majority of database models the key is a numeric Id. If your case is different, it is not covered in this article (if you can't figure it out, write to me via the feedback form); here we focus on the most common situations.

The context class doesn't change:

/// <summary>
/// The application's real Entity Framework context. It also implements
/// <see cref="IDataContext"/>, which is the abstraction the rest of the
/// application is meant to depend on.
/// </summary>
public class DataContext : DbContext, IDataContext
{
    /// <summary>Gets or sets the set of <see cref="Student"/> entities.</summary>
    public DbSet<Student> Students { get; set; }

    /// <summary>
    /// Creates the context, passing "DataContext" to the <see cref="DbContext"/>
    /// base constructor (its name-or-connection-string argument).
    /// </summary>
    public DataContext()
        : base("DataContext")
    {
    }
}

An important difference from the usual implementation is IDataContext:

/// <summary>
/// Abstraction over the data context. Depending on this interface rather than
/// on the concrete context lets tests supply a fake implementation.
/// </summary>
public interface IDataContext : IDisposable
{
    /// <summary>Gets or sets the set of <see cref="Student"/> entities.</summary>
    DbSet<Student> Students { get; set; }

    /// <summary>Saves pending changes synchronously.</summary>
    /// <returns>An integer result reported by the implementation.</returns>
    int SaveChanges();

    /// <summary>Saves pending changes asynchronously.</summary>
    /// <returns>A task producing the integer result reported by the implementation.</returns>
    Task<int> SaveChangesAsync();
}

It is this interface that should be used for all dependency injection (DI) in the application; I suggest forgetting about using DataContext directly. DI for testing purposes will be covered in more detail, I expect, in later articles.

That's all for the project under test. Let's move on to the test code, where we have to write (or, more likely, copy and paste :smiley:) a lot more...

Let's create a Context folder in the root of the (testing) project and put the following classes there:

/// <summary>
/// In-memory stand-in for the real context. It exposes a
/// <see cref="FakeDbSet{TEntity}"/> for each entity and reports every save as
/// a fixed result of 1 without persisting anything.
/// </summary>
internal class FakeContext : IDataContext
{
    /// <summary>Creates the fake context with an empty in-memory student set.</summary>
    public FakeContext()
    {
        Students = new FakeDbSet<Student>();
    }

    /// <summary>Gets or sets the in-memory set of <see cref="Student"/> entities.</summary>
    public DbSet<Student> Students { get; set; }

    /// <summary>No-op: the fake holds no resources that need releasing.</summary>
    public void Dispose()
    {
    }

    /// <summary>Pretends to save; nothing is persisted.</summary>
    /// <returns>Always 1.</returns>
    public int SaveChanges()
    {
        return 1;
    }

    /// <summary>Pretends to save asynchronously; nothing is persisted.</summary>
    /// <returns>An already-completed task whose result is always 1.</returns>
    public Task<int> SaveChangesAsync()
    {
        // There is nothing to await, so hand back a completed task instead of
        // marking the method async (which only triggers warning CS1998 and
        // builds a state machine for no benefit).
        return Task.FromResult(1);
    }
}
/// <summary>
/// Query provider for in-memory test data: it delegates execution to a wrapped
/// <see cref="IQueryProvider"/> and exposes the results through the
/// asynchronous provider interface as already-completed tasks.
/// </summary>
internal class TestDbAsyncQueryProvider<TEntity> : IDbAsyncQueryProvider
{
    // The provider that actually evaluates expressions.
    private readonly IQueryProvider _wrapped;

    /// <summary>Wraps <paramref name="inner"/>, which performs the real execution.</summary>
    internal TestDbAsyncQueryProvider(IQueryProvider inner)
    {
        _wrapped = inner;
    }

    /// <summary>Builds a non-generic query over <paramref name="expression"/>.</summary>
    public IQueryable CreateQuery(Expression expression)
    {
        var query = new TestDbAsyncEnumerable<TEntity>(expression);
        return query;
    }

    /// <summary>Builds a typed query over <paramref name="expression"/>.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        var query = new TestDbAsyncEnumerable<TElement>(expression);
        return query;
    }

    /// <summary>Executes <paramref name="expression"/> on the wrapped provider.</summary>
    public object Execute(Expression expression)
    {
        object outcome = _wrapped.Execute(expression);
        return outcome;
    }

    /// <summary>Executes <paramref name="expression"/> on the wrapped provider, typed.</summary>
    public TResult Execute<TResult>(Expression expression)
    {
        TResult outcome = _wrapped.Execute<TResult>(expression);
        return outcome;
    }

    /// <summary>
    /// Executes synchronously and wraps the result in a completed task.
    /// The cancellation token is not consulted.
    /// </summary>
    public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
    {
        object outcome = Execute(expression);
        return Task.FromResult(outcome);
    }

    /// <summary>
    /// Typed variant: executes synchronously and wraps the result in a
    /// completed task. The cancellation token is not consulted.
    /// </summary>
    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        TResult outcome = Execute<TResult>(expression);
        return Task.FromResult(outcome);
    }
}
/// <summary>
/// In-memory replacement for <see cref="DbSet{TEntity}"/> backed by an
/// <see cref="ObservableCollection{T}"/>. Entities added through
/// <see cref="Add"/> or <see cref="AddRange"/> receive an id from a counter
/// that starts at 0 and only ever grows, so an id is not handed out again
/// after an entity has been removed.
/// </summary>
internal class FakeDbSet<TEntity> : DbSet<TEntity>, IQueryable
    , IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>
    where TEntity : BaseEntity
{
    private readonly ObservableCollection<TEntity> _data;
    private readonly IQueryable _query;

    // Id given to the next added entity. Deriving the id from _data.Count
    // would repeat an id after a Remove, so a separate counter is kept.
    private int _nextId;

    /// <summary>Creates an empty set.</summary>
    public FakeDbSet()
    {
        _data = new ObservableCollection<TEntity>();
        _query = _data.AsQueryable();
    }

    /// <summary>Assigns the next id to <paramref name="item"/> and stores it.</summary>
    /// <returns>The same instance that was passed in.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    public override TEntity Add(TEntity item)
    {
        if (item == null)
        {
            throw new ArgumentNullException("item");
        }

        Store(item);
        return item;
    }

    /// <summary>Assigns ids to all <paramref name="entities"/> and stores them in order.</summary>
    /// <returns>The stored entities.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="entities"/> contains a null element.</exception>
    public override IEnumerable<TEntity> AddRange(IEnumerable<TEntity> entities)
    {
        if (entities == null)
        {
            throw new ArgumentNullException("entities");
        }

        // Enumerate exactly once: a lazy sequence could otherwise yield
        // different instances to the caller than the ones that were stored.
        var items = entities.ToList();

        // Validate before mutating so a bad element cannot leave a partial add.
        if (items.Any(candidate => candidate == null))
        {
            throw new ArgumentException("The sequence must not contain null entities.", "entities");
        }

        foreach (var item in items)
        {
            Store(item);
        }

        return items;
    }

    /// <summary>Removes <paramref name="item"/> from the set if it is present.</summary>
    /// <returns>The same instance that was passed in.</returns>
    public override TEntity Remove(TEntity item)
    {
        _data.Remove(item);
        return item;
    }

    /// <summary>
    /// Stores <paramref name="item"/> keeping the id it already has; ids handed
    /// out afterwards are greater than that id.
    /// </summary>
    /// <returns>The same instance that was passed in.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    public override TEntity Attach(TEntity item)
    {
        if (item == null)
        {
            throw new ArgumentNullException("item");
        }

        _data.Add(item);

        // Keep generated ids clear of the attached one; the upper-bound check
        // avoids overflowing the counter.
        if (item.Id >= _nextId && item.Id < int.MaxValue)
        {
            _nextId = item.Id + 1;
        }

        return item;
    }

    /// <summary>Creates a new entity instance without adding it to the set.</summary>
    public override TEntity Create()
    {
        return Activator.CreateInstance<TEntity>();
    }

    /// <summary>Creates a new derived entity instance without adding it to the set.</summary>
    public override TDerivedEntity Create<TDerivedEntity>()
    {
        return Activator.CreateInstance<TDerivedEntity>();
    }

    /// <summary>Gets the backing collection itself.</summary>
    public override ObservableCollection<TEntity> Local
    {
        get { return _data; }
    }

    Type IQueryable.ElementType
    {
        get { return _query.ElementType; }
    }

    Expression IQueryable.Expression
    {
        get { return _query.Expression; }
    }

    // Wrap the in-memory provider so queries are also usable through the
    // asynchronous provider interface.
    IQueryProvider IQueryable.Provider
    {
        get { return new TestDbAsyncQueryProvider<TEntity>(_query.Provider); }
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _data.GetEnumerator();
    }

    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
    {
        return _data.GetEnumerator();
    }

    IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
    {
        return new TestDbAsyncEnumerator<TEntity>(_data.GetEnumerator());
    }

    // Gives the entity the next id and appends it to the backing collection.
    private void Store(TEntity item)
    {
        item.Id = _nextId;
        _nextId++;
        _data.Add(item);
    }
}

/// <summary>
/// An <see cref="EnumerableQuery{T}"/> that can also be enumerated through the
/// asynchronous enumerable interface, for use with in-memory test data.
/// </summary>
internal class TestDbAsyncEnumerable<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T>, IQueryable<T>
{
    /// <summary>Creates the query from an expression tree.</summary>
    public TestDbAsyncEnumerable(Expression expression)
        : base(expression)
    {
    }

    /// <summary>Creates the query over an existing sequence.</summary>
    public TestDbAsyncEnumerable(IEnumerable<T> enumerable)
        : base(enumerable)
    {
    }

    // Route further query composition through the async-capable provider.
    IQueryProvider IQueryable.Provider
    {
        get
        {
            var provider = new TestDbAsyncQueryProvider<T>(this);
            return provider;
        }
    }

    /// <summary>Returns an asynchronous enumerator over this query's elements.</summary>
    public IDbAsyncEnumerator<T> GetAsyncEnumerator()
    {
        IEnumerator<T> source = this.AsEnumerable().GetEnumerator();
        return new TestDbAsyncEnumerator<T>(source);
    }

    // The non-generic interface member forwards to the typed one.
    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        return GetAsyncEnumerator();
    }
}

/// <summary>
/// Adapts a synchronous <see cref="IEnumerator{T}"/> to the asynchronous
/// enumerator interface by completing every step immediately.
/// </summary>
internal class TestDbAsyncEnumerator<T> : IDbAsyncEnumerator<T>
{
    // The synchronous enumerator that does the actual iteration.
    private readonly IEnumerator<T> _source;

    /// <summary>Wraps <paramref name="inner"/>.</summary>
    public TestDbAsyncEnumerator(IEnumerator<T> inner)
    {
        _source = inner;
    }

    /// <summary>Gets the element at the current position of the wrapped enumerator.</summary>
    public T Current
    {
        get { return _source.Current; }
    }

    // Non-generic view of the same element.
    object IDbAsyncEnumerator.Current
    {
        get { return Current; }
    }

    /// <summary>
    /// Advances the wrapped enumerator synchronously and returns the outcome
    /// as a completed task. The cancellation token is not consulted.
    /// </summary>
    public Task<bool> MoveNextAsync(CancellationToken cancellationToken)
    {
        bool advanced = _source.MoveNext();
        return Task.FromResult(advanced);
    }

    /// <summary>Disposes the wrapped enumerator.</summary>
    public void Dispose()
    {
        _source.Dispose();
    }
}

We will use FakeContext in the test code to fake our database; every database entity has to be added to it (in our case, one was enough). There is little to say about TestDbAsyncQueryProvider: copy it and it works. FakeDbSet is not worth discussing in detail either: in everyday situations it simply exists and lets you write tests, and that is already a blessing :) Difficulties begin when your models do not fit the logic of the BaseEntity class. As already mentioned, solving such problems is beyond the scope of this note.

Let's take a look at an example use case for a fictitious StudentService.

   [TestFixture]
    // Example tests for StudentService running against the in-memory FakeContext.
    public class StudentServiceTests
    {
        private IDataContext _context;
        private StudentService _service;

        [SetUp]
        public void SetUp()
        {
            // A fresh fake context per test keeps data from leaking between tests.
            _context = new FakeContext();
            _service = new StudentService(_context);
        }

        [TearDown]
        public void TearDown()
        {
            // IDataContext is IDisposable; release it even though the fake's
            // Dispose does nothing, so the test stays correct with other implementations.
            if (_context != null)
            {
                _context.Dispose();
            }
        }

        [Test, AutoData, CustomFixture]
        public void GetStudents_Success(List<Student> students)
        {
            // Arrange
            // A local initialised from a runtime value cannot be 'const' (CS0133).
            var count = students.Count;
            _context.Students.AddRange(students);

            // Act
            var result = _service.Get();

            // Assert
            // A constraint assertion reports both values when it fails.
            Assert.That(result.Count, Is.EqualTo(count));
        }
    }