Milvus 源码 vector_index 中 Dataset 数据类的使用
1、Dataset 数据类定义
using Value = std::any;
using ValuePtr = std::shared_ptr<Value>;

/// @brief Heterogeneous string-keyed property bag.
///
/// Each value is stored type-erased in a std::any held by a shared_ptr.
/// Set() and Get() serialize on an internal mutex. data() does NOT take the
/// lock: callers reading the map through it must not race with Set().
class Dataset {
 public:
    Dataset() = default;

    /// @brief Store @p v under key @p k, replacing any previous value.
    /// std::any stores std::decay_t<T>, so a later Get<T>() must request
    /// exactly that decayed type (e.g. int64_t vs int are distinct).
    template <typename T>
    void
    Set(const std::string& k, T&& v) {
        std::lock_guard<std::mutex> lk(mutex_);
        data_[k] = std::make_shared<Value>(std::forward<T>(v));
    }

    /// @brief Return a copy of the value stored under @p k as type T.
    /// @throws std::logic_error "Can't find this key: <k>" if @p k is absent.
    /// @throws std::logic_error "Type mismatch for key: <k>" if the stored
    ///         value is not of type T.
    /// Any other exception (e.g. from T's copy constructor) propagates as-is.
    template <typename T>
    T
    Get(const std::string& k) const {
        std::lock_guard<std::mutex> lk(mutex_);
        auto it = data_.find(k);
        if (it == data_.end() || it->second == nullptr) {
            throw std::logic_error("Can't find this key: " + k);
        }
        try {
            return std::any_cast<T>(*it->second);
        } catch (const std::bad_any_cast&) {
            // Previously reported as a missing key, which hid the real cause.
            throw std::logic_error("Type mismatch for key: " + k);
        }
    }

    /// @brief Direct read access to the underlying map (not locked).
    const std::map<std::string, ValuePtr>&
    data() const {
        return data_;
    }

 private:
    // mutable so that the const Get() can still lock.
    mutable std::mutex mutex_;
    std::map<std::string, ValuePtr> data_;
};
using DatasetPtr = std::shared_ptr<Dataset>;
2、定义一个用于解析 Dataset 的宏
该宏直接从 dataset_ptr 中取出 dim、rows 和数据指针 p_data，这三个变量会被声明在调用宏的作用域中。
// GETTENSOR(dataset_ptr)
// Declares three local variables in the caller's scope, read from the dataset:
//   int64_t     dim    <- meta::DIM
//   int64_t     rows   <- meta::ROWS
//   const void* p_data <- meta::TENSOR
// Each value must have been stored with exactly the listed type, otherwise
// Dataset::Get throws std::logic_error. The argument is parenthesized so that
// expression arguments expand correctly, but it is still evaluated three times:
// pass a plain variable. Can be used only once per scope (the names would clash).
#define GETTENSOR(dataset_ptr)                                           \
    int64_t dim = (dataset_ptr)->Get<int64_t>(meta::DIM);                \
    int64_t rows = (dataset_ptr)->Get<int64_t>(meta::ROWS);              \
    const void* p_data = (dataset_ptr)->Get<const void*>(meta::TENSOR);
3、使用示例
调用宏之后，即可在当前作用域中直接使用解析出的 dim、rows 和数据指针 p_data。
void IVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { GETTENSOR(dataset_ptr) int64_t nlist = config[IndexParams::nlist].get<int64_t>(); faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>()); faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, nlist, metric_type); index->own_fields = true; index->train(rows, reinterpret_cast<const float*>(p_data)); index_ = index; }