GAN(Generative Adversarial Network)を用いた因果探索の実装

GAN(Generative Adversarial Network)を用いた因果探索の実装

ードの目的

  • 因果構造データの生成 サンプルデータには、X1 -> X2 -> X3 の因果構造を仮定しています。
  • GANを使用したデータ生成 GANで生成したデータが、元の因果構造をどの程度再現できるか確認します。
  • DAGの描画 真の因果構造(DAG)を視覚化。

背景

  • GANを用いる因果探索では、生成モデルがデータの分布を再現し、それを基に因果構造を明らかにするアプローチが採られます。
  • 生成されたデータを更に解析することで、新たな因果関係の発見や因果ネットワークの推定が可能です。

このコードは基礎的な例で、実際の応用では、生成データと実データ間の因果一致をさらに詳しく評価する必要があります。

引用アドレス https://deus-ex-machina-ism.com/?p=52535&amp=1

GANGenerative Adversarial Network)を用いた因果探索の実装は、生成モデルと因果関係推定の組み合わせを活用する高度なアプローチです。以下に、シンプルなGANを利用して因果構造の学習を行う例を示します。この例では、因果構造を暗黙的に学習し、生成されるデータを基に解析します。

コード例 GANで因果探索

以下は、GANを活用して因果グラフを推定する簡単な例です。

python

import numpy as np

import tensorflow as tf

from tensorflow.keras import layers

import networkx as nx

import matplotlib.pyplot as plt

 

# データ生成(因果構造に基づくサンプルデータ)

def generate_data(n_samples):

    x1 = np.random.uniform(-1, 1, n_samples)

    x2 = 2 * x1 + np.random.normal(scale=0.1, size=n_samples)  # x1 -> x2

    x3 = 0.5 * x1 - 0.5 * x2 + np.random.normal(scale=0.1, size=n_samples)  # x1, x2 -> x3

    data = np.column_stack([x1, x2, x3])

    return data

 

n_samples = 1000

data = generate_data(n_samples)

 

# GANGeneratorの構築

def build_generator(input_dim, output_dim):

    model = tf.keras.Sequential([

        layers.Input(shape=(input_dim,)),

        layers.Dense(64, activation="relu"),

        layers.Dense(output_dim)

    ])

    return model

 

# GANDiscriminatorの構築

def build_discriminator(input_dim):

    model = tf.keras.Sequential([

        layers.Input(shape=(input_dim,)),

        layers.Dense(64, activation="relu"),

        layers.Dense(1, activation="sigmoid")

    ])

    return model

 

# モデル構築

input_dim = data.shape[1]

generator = build_generator(input_dim, input_dim)

discriminator = build_discriminator(input_dim)

 

# GANのトレーニング設定

discriminator.compile(optimizer="adam", loss="binary_crossentropy")

gan = tf.keras.Sequential([generator, discriminator])

discriminator.trainable = False

gan.compile(optimizer="adam", loss="binary_crossentropy")

 

# トレーニング

epochs = 5000

batch_size = 32

for epoch in range(epochs):

    # 真のデータと生成データ

    idx = np.random.randint(0, data.shape[0], batch_size)

    real_data = data[idx]

    noise = np.random.normal(size=(batch_size, input_dim))

    fake_data = generator.predict(noise, verbose=0)

 

    # Discriminatorの訓練

    d_loss_real = discriminator.train_on_batch(real_data, np.ones((batch_size, 1)))

    d_loss_fake = discriminator.train_on_batch(fake_data, np.zeros((batch_size, 1)))

 

    # Generatorの訓練

    noise = np.random.normal(size=(batch_size, input_dim))

    g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

 

    if epoch % 500 == 0:

        print(f"Epoch {epoch}: D loss real = {d_loss_real:.4f}, D loss fake = {d_loss_fake:.4f}, G loss = {g_loss:.4f}")

 

# 因果グラフの描画

def plot_causal_graph():

    G = nx.DiGraph()

    G.add_edges_from([("X1", "X2"), ("X1", "X3"), ("X2", "X3")])  # 真の因果構造

    plt.figure(figsize=(6, 4))

    pos = nx.spring_layout(G)

    nx.draw(G, pos, with_labels=True, node_color="lightblue", node_size=2000, font_size=10, font_color="black")

    plt.title("True Causal Graph")

    plt.show()

 

plot_causal_graph()

出力結果

/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py:82: UserWarning: The model does not have any trainable weights.
  warnings.warn("The model does not have any trainable weights.")
Epoch 0: D loss real = 0.7863, D loss fake = 0.7490, G loss = 0.6795
Epoch 500: D loss real = 1.4664, D loss fake = 1.4686, G loss = 0.2282
Epoch 1000: D loss real = 2.1297, D loss fake = 2.1314, G loss = 0.1203
Epoch 1500: D loss real = 2.5706, D loss fake = 2.5720, G loss = 0.0813
Epoch 2000: D loss real = 2.8968, D loss fake = 2.8980, G loss = 0.0613
Epoch 2500: D loss real = 3.1566, D loss fake = 3.1576, G loss = 0.0492
Epoch 3000: D loss real = 3.3737, D loss fake = 3.3745, G loss = 0.0411
Epoch 3500: D loss real = 3.5629, D loss fake = 3.5637, G loss = 0.0353
Epoch 4000: D loss real = 3.7301, D loss fake = 3.7307, G loss = 0.0309
Epoch 4500: D loss real = 3.8828, D loss fake = 3.8834, G loss = 0.0275



このブログの人気の投稿

片貝の四尺玉は世界一を連呼する『片貝賛歌~希望の花~』を作詞しました!!

小論文 統計的因果推論の現場適用による排泄ケアの展望

論文 排泄ケアにおける尊厳の保持と社会システムの課題 ~「おむつ」をめぐる心理的・文化的考察~