ひっそりと生きるプログラマのブログ

日頃気になった事なりを書き留めるブログです。関心ごとは多くもう少し更新頻度を上げたいところです。

【Entity Framework】テーブルの主キーの値を渡して、レコードを取得する方法(ナビゲーションの値も含む)

以下のような実装で取得可能。

次の実装は呼び出し元

SampleTable というテーブルがあり、
主キーに相当するプロパティが SampleId になります。
SampleSubTable と SampleTable はナビゲーションプロパティが定義しています。

static void Main(string[] args)
{
    SampleTable data = null;
    using (var context = new Sample01Entities())
    {
        context.Configuration.ProxyCreationEnabled = false;
        data = GetSampleTableData(context, 2);
    }
    Console.WriteLine(data.SampleId);
    Console.WriteLine(data.SampleName);
    Console.WriteLine(data.SampleSubTable.Count);
    Console.ReadLine();
}

private static SampleTable GetSampleTableData(
    DbContext context, int key1)
{
    return DbContextUtility.GetData<SampleTable>(
        context,
        new SampleTable()
        {
            SampleId = key1
        });
}

次の実装は呼び出し先
リファクタリングの余地はあるが、
実現したいことはできていると思います。

public static class DbContextUtility
{
    #region EntityType
    private static Dictionary<Type, EntityType> entityTypes =
        new Dictionary<Type, EntityType>();

    private static EntityType GetEntityType(DbContext context, Type type)
    {
        EntityType result;
        if (entityTypes.TryGetValue(type, out result)) return result;
        var objectContext = ((IObjectContextAdapter)context).ObjectContext;
        var metadata = objectContext.MetadataWorkspace;
        result = metadata.GetItem<EntityType>(type.FullName, DataSpace.OSpace);
        entityTypes.Add(type, result);
        return result;
    }
    #endregion

    public static T GetData<T>(
        DbContext context, T entity) where T : class
    {
        var type = typeof(T);
        var entityType = GetEntityType(context, type);
        var keyProperties = entityType.KeyMembers.Select(k => type.GetProperty(k.Name));
        var values = keyProperties.Select(m => m.GetValue(entity)).ToArray();
        var result = (T)context.Set<T>().Find(values);
        GetEntity(context, result, new HashSet<object>());
        return (T)result;
    }

    private static void GetEntity(
        DbContext context, object entity, HashSet<object> values)
    {
        if (values.Contains(entity)) return;
        values.Add(entity);
        var type = entity.GetType();
        var entityType = GetEntityType(context, type);
        var navproperties = entityType.NavigationProperties.Select(
            m => entity.GetType().GetProperty(m.Name));
        foreach (var nav in navproperties)
        {
            if (nav.PropertyType.IsGenericType)
            {
                context.Entry(entity).Collection(nav.Name).Load();
                var enumerable = (IEnumerable)nav.GetValue(entity);
                foreach(var e in enumerable)
                    GetEntity(context, e, values);
            }
            else
            {
                context.Entry(entity).Reference(nav.Name).Load();
                var value = nav.GetValue(entity);
                GetEntity(context, value, values);
            }
        }
    }
}