DirectLiNGAMを用いた因果探索
③ DirectLiNGAMを用いた因果探索
アルゴリズムを用いて因果探索を行い、生成データの因果関係を推定し、得られた結果をDAG(Directed Acyclic Graph: 有向非巡回グラフ)として可視化しています。以下に、各部分の詳細な解説と因果探索手法に関連する背景を示します。
import
numpy as np
import
pandas as pd
import
networkx as nx
import
matplotlib.pyplot as plt
from
lingam import DirectLiNGAM
# データ生成
#
Temperature(気温)を共通原因とし、IceCreamSales(アイスクリーム売上)とSwimmingAccidents(水泳事故件数)に影響を与える構造を持つデータを生成しています。各変数にはランダムノイズが加えられており、実際のデータセットに近い状況を模擬しています。
np.random.seed(42)
n_samples
= 300
temperature
= np.random.uniform(20, 40, n_samples)
ice_cream_sales
= temperature + np.random.normal(0, 2, n_samples)
swimming_accidents
= temperature + np.random.normal(0, 1, n_samples)
data
= pd.DataFrame({
'Temperature': temperature,
'IceCreamSales': ice_cream_sales,
'SwimmingAccidents': swimming_accidents
})
# 因果探索:DirectLiNGAMを使用してDAGを構築
#
DirectLiNGAMは、LiNGAM(Linear
Non-Gaussian Acyclic Model)は、線形で非ガウス分布を仮定した因果探索モデルで、独立成分分析(ICA: Independent Component Analysis)を利用して因果関係を推定します。得られる隣接行列(adjacency matrix)は、変数間の因果関係の強さを数値化しています。
model
= DirectLiNGAM()
model.fit(data)
# 因果関係の隣接行列を取得
# 隣接行列には、各変数間の因果関係を示す値が格納されています。非ゼロの値は、直接的な因果関係を意味します。例えば、Temperature -> IceCreamSales の重みが大きい場合、「Temperature」が「IceCreamSales」に強い因果的影響を与えることを示します。
adj_matrix
= model.adjacency_matrix_
var_names
= data.columns
#
DAGのエッジとスコアを抽出
# 隣接行列からエッジ(因果関係)を抽出し、各エッジに対応するスコア(因果効果の強さ)を記録します。DAGの構造が定義され、スコアは関係の強さを明示的に示します。
edges
= []
edge_scores
= {}
for
i, row in enumerate(adj_matrix):
for j, weight in enumerate(row):
if weight != 0: # 非ゼロの重みは因果関係を示す
edges.append((var_names[j], var_names[i]))
edge_scores[(var_names[j],
var_names[i])] = weight
# グラフの描画
plt.figure(figsize=(8,
6))
G
= nx.DiGraph(edges)
pos
= nx.spring_layout(G, seed=42) # ノードの配置を決定
edge_labels
= {(u, v): f"{edge_scores[(u, v)]:.2f}" for u, v in G.edges()}
#
DAGの描画
# NetworkXを利用してDAGを描画し、エッジにスコアを表示。「Temperature」が「IceCreamSales」と「SwimmingAccidents」に矢印を向けている場合、共通原因として適切な関係を示していると判断できます。
nx.draw(G,
pos, with_labels=True, node_color='lightblue', node_size=2000, font_size=10,
font_weight='bold', arrowsize=20)
nx.draw_networkx_edge_labels(G,
pos, edge_labels=edge_labels, font_color='red', font_size=8)
plt.title("Causal
DAG with Causal Effect Scores")
plt.show()
出力結果