CORA Dataloader 分析

from .dataset_loader import DatasetLoader

class Cora(
    DatasetLoader,
    name="Cora",
    directory_name="cora",
    url="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    url_archive_format="gztar",
    expected_files=["cora.cites", "cora.content"],
    description="The Cora dataset consists of 2708 scientific publications classified into one of seven classes. "
    "The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector "
    "indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words.",
    source="https://linqs.soe.ucsc.edu/data",
):
    # NOTE(review): this class deliberately has no docstring of its own; presumably
    # DatasetLoader derives the documentation from `description`/`source` above.
    # Confirm against DatasetLoader before adding one.

    # Number of word-indicator feature columns; equals the dictionary size quoted in
    # the description (1433 unique words).
    _NUM_FEATURES = 1433

    def load(
        self,
        directed=False,
        largest_connected_component_only=False,
        subject_as_feature=False,
        edge_weights=None,
        str_node_ids=False,
    ):
        """
        Download the Cora dataset if required and load it as a homogeneous graph.

        Node feature vectors are always attached to the graph. Whether the citation edges are
        treated as directed or undirected is controlled by ``directed``.

        Args:
            directed (bool): if True, build a directed graph; otherwise build an undirected one.
            largest_connected_component_only (bool): if True, keep only the largest connected
                component instead of the whole graph.
            subject_as_feature (bool): if True, the subject of each paper (node) is added to the
                node features as a one-hot encoding (the subjects are still returned separately
                as a Series).
            edge_weights (callable, optional): a function taking three arguments: the unweighted
                StellarGraph (with node features), a Pandas Series of node labels, and a Pandas
                DataFrame of edges (with `source` and `target` columns). It should return a
                sequence of numbers (e.g. a 1D NumPy array) holding one weight per edge in the
                DataFrame.
            str_node_ids (bool): if True, node IDs are loaded as strings instead of integers.

        Returns:
            A 2-tuple: first the :class:`.StellarGraph` (or :class:`.StellarDiGraph` when
            ``directed == True``) holding nodes, node feature vectors and edges; second a pandas
            Series with the subject class label of each node.
        """
        # The shared loader takes the concrete dtype for node IDs rather than a flag.
        if str_node_ids:
            id_dtype = str
        else:
            id_dtype = int

        return _load_cora_or_citeseer(
            dataset=self,
            directed=directed,
            largest_connected_component_only=largest_connected_component_only,
            subject_as_feature=subject_as_feature,
            edge_weights=edge_weights,
            nodes_dtype=id_dtype,
        )

def _load_cora_or_citeseer(
    dataset,
    directed,
    largest_connected_component_only,
    subject_as_feature,
    edge_weights,
    nodes_dtype,
):
    """
    Shared loading routine for the Cora and CiteSeer citation datasets.

    Downloads the dataset if required, reads the node content and citation files, drops edges
    that refer to nodes missing from the content file, and builds the graph.

    Args:
        dataset: the :class:`Cora` or :class:`CiteSeer` loader instance to read from.
        directed (bool): if True, build a ``StellarDiGraph``, otherwise a ``StellarGraph``.
        largest_connected_component_only (bool): if True, return only the subgraph (and the
            matching subjects) of the first component yielded by ``graph.connected_components()``
            (callers document this as the largest component).
        subject_as_feature (bool): if True, the ``subject`` column is one-hot encoded and included
            in the node features.
        edge_weights (callable, optional): called as ``edge_weights(graph, subjects, edgelist)``
            with the unweighted graph; its result becomes the ``weight`` column of the edges and
            the graph is rebuilt with those weights.
        nodes_dtype: dtype used when parsing node IDs; if None, ``dataset._NODES_DTYPE`` is used.

    Returns:
        A tuple ``(graph, subjects)`` where ``subjects`` is the ``subject`` column of the node data.
    """
    # Internal invariant check: this helper is only meant for the two loaders it is named after.
    assert isinstance(dataset, (Cora, CiteSeer))

    if nodes_dtype is None:
        nodes_dtype = dataset._NODES_DTYPE

    dataset.download()

    # expected_files should be in this order
    cites, content = [dataset._resolve_path(name) for name in dataset.expected_files]

    feature_names = ["w_{}".format(ii) for ii in range(dataset._NUM_FEATURES)]
    subject = "subject"
    if subject_as_feature:
        feature_names.append(subject)
        column_names = feature_names
    else:
        column_names = feature_names + [subject]

    # NOTE(review): this relies on the content file having one more column than `column_names`,
    # so that pandas uses the leading node-ID column as the index -- confirm against the data.
    node_data = pd.read_csv(
        content, sep="\t", header=None, names=column_names, dtype={0: nodes_dtype}
    )

    edgelist = pd.read_csv(
        cites, sep="\t", header=None, names=["target", "source"], dtype=nodes_dtype
    )

    # Keep only edges whose endpoints both appear in the node index (get_indexer gives -1
    # for IDs that are not present).
    valid_source = node_data.index.get_indexer(edgelist.source) >= 0
    valid_target = node_data.index.get_indexer(edgelist.target) >= 0
    edgelist = edgelist[valid_source & valid_target]

    subjects = node_data[subject]

    cls = StellarDiGraph if directed else StellarGraph

    features = node_data[feature_names]
    if subject_as_feature:
        # one-hot encode the subjects
        features = pd.get_dummies(features, columns=[subject])

    graph = cls({"paper": features}, {"cites": edgelist})

    if edge_weights is not None:
        # A weighted graph means computing a second StellarGraph after using the unweighted one to
        # compute the weights.
        weights = edge_weights(graph, subjects, edgelist)
        # `assign` returns a new frame, so we never write into the filtered selection above.
        edgelist = edgelist.assign(weight=weights)
        # Reuse `features` (not the raw node_data columns) so the weighted graph carries the same
        # node features as the unweighted one, including the one-hot encoded subject.
        graph = cls({"paper": features}, {"cites": edgelist})

    if largest_connected_component_only:
        cc_ids = next(graph.connected_components())
        return graph.subgraph(cc_ids), subjects[cc_ids]

    return graph, subjects

posted @ 2023-06-20 18:36  ZZX11  阅读(26)  评论(0编辑  收藏  举报