[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache
[toc]
Next we look at the caching mechanism. Why cache? Because a cluster contains many workers, and interaction is needed both between the Master and Workers and among the Workers themselves, it pays to cache each Worker together with its gRPC channel. In fact, caching is used all over the place in distributed TensorFlow.
Other articles in this series:
[Translation] TensorFlow distributed papers: "Implementation of Control Flow in TensorFlow"
[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture
[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic
[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic
1. WorkerCache
The job of WorkerCache is to hand out WorkerInterface instances, which can access the remote WorkerService. The canonical WorkerInterface implementation is GrpcRemoteWorker.
1.1 How It Is Used
When MasterEnv was initialized earlier, WorkerCacheFactory was installed as master_env_.worker_cache_factory.
```cpp
master_env_.worker_cache_factory =
    [this](const WorkerCacheFactoryOptions& options,
           WorkerCacheInterface** worker_cache) {
      return WorkerCacheFactory(options, worker_cache);
    };
```
Later, in Master::CreateSession, the abridged code below shows how a worker_cache (a WorkerCacheInterface instance) is obtained from the factory, and how it is then used.
```cpp
void Master::CreateSession(const CreateSessionRequest* req,
                           CreateSessionResponse* resp, MyClosure done) {
  SchedClosure([this, req, resp, done]() {
    // Configure the options.
    WorkerCacheFactoryOptions worker_cache_factory_options;
    worker_cache_factory_options.protocol = &grpc_protocol;
    worker_cache_factory_options.rpc_options = &req->config().rpc_options();

    // Create the worker cache from the computed server_def.
    status = env_->worker_cache_factory(worker_cache_factory_options,
                                        &worker_cache);

    // Use worker_cache for the subsequent steps.
    status = DeviceFinder::GetRemoteDevices(req->config().device_filters(),
                                            env_, worker_cache,
                                            remote_devices.get());
  });
}
```
1.2 Configuration
WorkerCacheFactoryOptions is essentially a view over ServerDef: it carries the ClusterDef, job_name, task_index, and related fields.
```cpp
// Options passed to the worker_cache_factory function.
struct WorkerCacheFactoryOptions {
  const ClusterDef* cluster_def = nullptr;
  const string* job_name = nullptr;
  int task_index;
  const string* protocol = nullptr;
  const RPCOptions* rpc_options = nullptr;

  WorkerCacheFactoryOptions() {}

  // Construct from a ServerDef proto.
  //
  // Note: server_def must outlive WorkerCacheFactoryOptions!
  WorkerCacheFactoryOptions(const ServerDef& server_def) {
    if (server_def.has_cluster() && !server_def.job_name().empty()) {
      cluster_def = &server_def.cluster();
      job_name = &server_def.job_name();
      task_index = server_def.task_index();
      protocol = &server_def.protocol();
      rpc_options = &server_def.default_session_config().rpc_options();
    }
  }
};
```
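For reference, here is a minimal sketch of building a ServerDef and deriving options from it, assuming the TensorFlow proto headers are available; the two-task cluster layout and addresses are made up for illustration.

```cpp
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

tensorflow::ServerDef MakeServerDef() {
  tensorflow::ServerDef server_def;
  // One job "worker" with two tasks; addresses are illustrative.
  tensorflow::JobDef* job = server_def.mutable_cluster()->add_job();
  job->set_name("worker");
  (*job->mutable_tasks())[0] = "localhost:2222";
  (*job->mutable_tasks())[1] = "localhost:2223";
  server_def.set_job_name("worker");  // This server plays /job:worker/task:0.
  server_def.set_task_index(0);
  server_def.set_protocol("grpc");
  return server_def;
}

// WorkerCacheFactoryOptions stores raw pointers into the ServerDef, so the
// ServerDef must stay alive at least as long as the options object
// (as the comment in the struct warns):
//   tensorflow::ServerDef server_def = MakeServerDef();
//   WorkerCacheFactoryOptions options(server_def);
```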
1.3 The Factory
WorkerCacheFactory is a function that does the following:
- Uses ParseChannelSpec to build a GrpcChannelSpec instance; a GrpcChannelSpec is essentially a ClusterSpec carrying the cluster's basic configuration.
- Uses NewGrpcChannelCache to obtain a GrpcChannelCache channel_cache. This is where GetChannelCreationFunction comes in (covered later).
- Uses NewGrpcWorkerCacheWithLocalWorker(channel_cache) to produce the worker_cache.
```cpp
Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                                      WorkerCacheInterface** worker_cache) {
  // Get the GrpcChannelSpec.
  GrpcChannelSpec channel_spec;
  TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));

  // Get the GrpcChannelCache.
  std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
      channel_spec, GetChannelCreationFunction(), *options.rpc_options));

  string name_prefix = strings::StrCat("/job:", *options.job_name,
                                       "/replica:0",
                                       "/task:", options.task_index);

  const string host_port = channel_cache->TranslateTask(name_prefix);
  int requested_port;

  auto colon_index = host_port.find_last_of(':');
  if (!strings::safe_strto32(host_port.substr(colon_index + 1),
                             &requested_port)) {
    return errors::Internal("Could not parse port for local server from \"",
                            host_port, "\".");
  }
  if (requested_port != bound_port_) {
    return errors::InvalidArgument("Requested port ", requested_port,
                                   " differs from expected port ",
                                   bound_port_);
  }

  // Get the worker cache.
  *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
      channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
  return Status::OK();
}
```
1.3.1 ParseChannelSpec
ParseChannelSpec builds the GrpcChannelSpec instance; a GrpcChannelSpec is essentially a ClusterSpec carrying the cluster's basic configuration.
```cpp
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
                                    GrpcChannelSpec* channel_spec) {
  for (const auto& job : options.cluster_def->job()) {
    std::map<int, string> host_ports;
    for (const auto& task : job.tasks()) {
      string& host_port = host_ports[task.first];
      if (!host_port.empty()) {
        return errors::InvalidArgument(
            "JobDef for job \"", job.name(),
            "\" specified two addresses for task \"", task.first,
            "\": ", host_port, " and ", task.second);
      }
      if (job.name() == *options.job_name &&
          task.first == options.task_index) {
        host_port = strings::StrCat(host_name_, ":", bound_port_);
      } else {
        host_port = task.second;
      }
    }
    TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
  }
  return Status::OK();
}
```
1.3.2 NewGrpcChannelCache
NewGrpcChannelCache creates the GrpcChannelCache instance. As the code shows, each job gets one SparseGrpcChannelCache. If there is only one SparseGrpcChannelCache, it is returned directly; otherwise the caches are combined into a MultiGrpcChannelCache. The channel_func passed in is GetChannelCreationFunction, which we cover later.
```cpp
GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
                                      ChannelCreationFunction channel_func,
                                      const RPCOptions& options) {
  const int num_jobs = spec.host_ports_jobs().size();
  if (!num_jobs) {
    return nullptr;
  }
  std::vector<GrpcChannelCache*> caches;
  caches.reserve(num_jobs);
  for (auto& job : spec.host_ports_jobs()) {
    caches.push_back(
        new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func,
                                   options.num_channels_per_target()));
  }
  return caches.size() == 1 ? caches[0]
                            : new MultiGrpcChannelCache(
                                  caches, options.num_channels_per_target());
}
```
1.3.3 NewGrpcWorkerCacheWithLocalWorker
NewGrpcWorkerCacheWithLocalWorker creates the GrpcWorkerCache instance.
```cpp
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
    std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
    WorkerInterface* local_worker, const string& local_target) {
  return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
}
```
The local_worker argument is obtained from worker_impl() and passed in; it is created in GrpcServer::Init and is simply the local GrpcWorker.
```cpp
GrpcWorker* worker_impl() const { return worker_impl_.get(); }

std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
                                          const ConfigProto& config) {
  return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
}

Status GrpcServer::Init(const GrpcServerOptions& opts) {
  // ... omitted ...
  worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                  : NewGrpcWorker(&worker_env_, config);
  // ... omitted ...
}
```
Let's recap the factory flow so far: the initial input is a WorkerCacheFactoryOptions, which is transformed step by step by the functions above until a GrpcWorkerCache comes out.
Figure 1: Factory flow
1.4 WorkerCacheInterface
1.4.1 The Interface
WorkerCacheInterface is an interface class; the GrpcWorkerCache in the figure above is one of its derived classes.
```cpp
class WorkerCacheInterface {
 public:
  virtual ~WorkerCacheInterface() {}

  // Updates *workers with strings naming the remote worker tasks to
  // which open channels have been established.
  virtual void ListWorkers(std::vector<string>* workers) const = 0;
  virtual void ListWorkersInJob(const string& job_name,
                                std::vector<string>* workers) const = 0;

  // If "target" names a remote task for which an RPC channel exists
  // or can be constructed, returns a pointer to a WorkerInterface object
  // wrapping that channel. The returned value must be destroyed by
  // calling `this->ReleaseWorker(target, ret)`
  virtual WorkerInterface* GetOrCreateWorker(const string& target) = 0;

  // Release a worker previously returned by this->GetOrCreateWorker(target).
  //
  // TODO(jeff,sanjay): Consider moving target into WorkerInterface.
  // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
  // per-rpc-subsystem WorkerInterface creator.
  virtual void ReleaseWorker(const string& target, WorkerInterface* worker) {
    // Subclasses may override to reuse worker objects.
    delete worker;
  }

  // Set *locality with the DeviceLocality of the specified remote device
  // within its local environment. Returns true if *locality
  // was set, using only locally cached data. Returns false
  // if status data for that device was not available. Never blocks.
  virtual bool GetDeviceLocalityNonBlocking(const string& device,
                                            DeviceLocality* locality) = 0;

  // Set *locality with the DeviceLocality of the specified remote device
  // within its local environment. Callback gets Status::OK if *locality
  // was set.
  virtual void GetDeviceLocalityAsync(const string& device,
                                      DeviceLocality* locality,
                                      StatusCallback done) = 0;

  // TODO(b/189159585): Define a general client cache maker function to
  // construct client cache of different types sharing the same underling RPC
  // channels, to replace the eager and coordination cache function.
  // Build and return a EagerClientCache object wrapping that channel.
  virtual Status GetEagerClientCache(
      std::unique_ptr<eager::EagerClientCache>* eager_client_cache) = 0;

  // Build and return a CoordinationClientCache object wrapping that channel.
  virtual Status GetCoordinationClientCache(
      std::unique_ptr<CoordinationClientCache>* coordination_client_cache) = 0;

  // Start/stop logging activity.
  virtual void SetLogging(bool active) {}

  // Discard any saved log data.
  virtual void ClearLogs() {}

  // Return logs for the identified step in *ss. Any returned data will no
  // longer be stored.
  virtual bool RetrieveLogs(int64_t step_id, StepStats* ss) { return false; }
};
```
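The interface comments imply a get/release lifecycle for workers (and NewRemoteDevices, shown later, follows it). A minimal usage sketch; the target name is illustrative:

```cpp
const string target = "/job:worker/replica:0/task:1";
WorkerInterface* wi = worker_cache->GetOrCreateWorker(target);
if (wi != nullptr) {
  // ... issue RPCs through wi, e.g. wi->GetStatusAsync(...) ...
  // Hand the worker back to the cache instead of deleting it directly,
  // so subclasses get a chance to reuse the object.
  worker_cache->ReleaseWorker(target, wi);
}
```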
WorkerCachePartial in turn derives from WorkerCacheInterface.
```cpp
// Implements the part of the interface that caches and returns remote
// device status attributes.
class WorkerCachePartial : public WorkerCacheInterface {
 public:
  bool GetDeviceLocalityNonBlocking(const string& device,
                                    DeviceLocality* locality) override;

  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
                              StatusCallback) override;

  ~WorkerCachePartial() override {}

  // Clear all entries from the DeviceStatus cache.
  void FlushStatusCache();

 private:
  mutex mu_;

  // Initiate a GetStatusAsync to the remote task named by "task", and
  // update the cache with all the DeviceAttributes reported.
  Status RefreshDeviceStatus(const string& device_name);

  typedef std::unordered_map<string, DeviceAttributes> StatusMap;
  StatusMap device_status_cache_ TF_GUARDED_BY(mu_);
};
```
1.4.2 GrpcWorkerCache
GrpcWorkerCache then derives from WorkerCachePartial.
```cpp
class GrpcWorkerCache : public WorkerCachePartial {
 public:
  explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
                           WorkerInterface* local_worker,
                           const string& local_target,
                           GrpcWorkerEnv* worker_env)
      : local_target_(local_target),
        local_worker_(local_worker),
        channel_cache_(channel_cache),
        worker_env_(worker_env),
        next_round_robin_assignment_(0) {}

  const string local_target_;
  WorkerInterface* const local_worker_;  // Not owned.
  std::shared_ptr<GrpcChannelCache> channel_cache_;
  WorkerCacheLogger logger_;
  GrpcWorkerEnv* worker_env_;  // Not owned

  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      TF_GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
};
```
One of its main functions is ListWorkers, which lists the names of all workers in the cluster by delegating to the channel cache.
```cpp
void ListWorkers(std::vector<string>* workers) const override {
  channel_cache_->ListWorkers(workers);
}

void ListWorkersInJob(const string& job_name,
                      std::vector<string>* workers) const override {
  channel_cache_->ListWorkersInJob(job_name, workers);
}
```
GetOrCreateWorker builds a worker on top of the worker's RPC channel. For the local target it simply returns local_worker_, i.e. the local GrpcWorker we configured earlier.
```cpp
WorkerInterface* GetOrCreateWorker(const string& target) override {
  if (target == local_target_) {
    return local_worker_;
  } else {
    SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
    if (!channel) {
      return nullptr;
    }
    size_t index = AssignWorkerToThread(target);
    return NewGrpcRemoteWorker(channel, worker_env_->GetCompletionQueue(index),
                               worker_env_->GetThreadPool(), &logger_, target);
  }
}
```
2. RPC Channels
Workers run on top of RPC channels, so next we look at how those channels are established. Just as workers are cached, RPC channels are cached as well: GrpcChannelCache is that cache, used to fetch or create RPC channels to the remote workers in the cluster.
2.1 The GrpcChannelCache Interface
GrpcChannelCache is an interface class that defines operations such as:
- ListWorkers: returns the names of the workers in the cluster.
- TranslateTask: converts a worker name into its address, in host:port format.
- FindWorkerChannel: looks up a grpc::Channel instance in the cache; on a miss, it dynamically creates one from the address information and inserts it into the cache.
```cpp
class GrpcChannelCache {
 public:
  virtual ~GrpcChannelCache() {}

  // Populates *workers with names of all workers which this object
  // was created to handle. Worker names are in the format
  //  /job:<job identifier>/task:<task id>
  // e.g. /job:mnist/task:2
  virtual void ListWorkers(std::vector<string>* workers) = 0;
  virtual void ListWorkersInJob(const string& job_name,
                                std::vector<string>* workers) = 0;

  // If found, returns a gRPC channel that is connected to the remote
  // worker named by 'target'. 'target' is of the following
  // format: /job:<job identifier>/task:<task id>
  // E.g., /job:mnist/task:2
  virtual SharedGrpcChannelPtr FindWorkerChannel(const string& target) = 0;

  // Translates a string in the form `/job:X/task:Z` into a host_port.
  virtual string TranslateTask(const string& task) = 0;
};
```
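A minimal usage sketch of this interface, assuming channel_cache was produced by NewGrpcChannelCache as in the factory above; the worker names and addresses are illustrative:

```cpp
std::vector<string> workers;
channel_cache->ListWorkers(&workers);  // e.g. {"/job:mnist/task:0", ...}

const string target = "/job:mnist/task:2";
string host_port = channel_cache->TranslateTask(target);  // e.g. "10.0.0.2:2222"

SharedGrpcChannelPtr channel = channel_cache->FindWorkerChannel(target);
if (channel == nullptr) {
  // The target does not belong to any job in the cluster definition.
}
```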
2.2 The Caching Mechanism
CachingGrpcChannelCache is the caching class that avoids the cost of creating a grpc::Channel on every call. It is defined as GenericCachingChannelCache instantiated over GrpcChannelCache:
```cpp
// GrpcChannelCache that caches results to FindWorkerChannel() calls.
using CachingGrpcChannelCache = GenericCachingChannelCache<GrpcChannelCache>;
```
GenericCachingChannelCache caches the results of FindWorkerChannel() calls: it first looks for the grpc::Channel instance in the cache and, on a miss, calls FindChannelOnce to create one from the address information, then inserts it into the cache.

GenericCachingChannelCache also allows multiple channels to the same target, to improve throughput. When multiple channels exist for one target, each FindWorkerChannel call selects among them round-robin.

Note that, given the definition below, absl::flat_hash_map<string, ChannelState> channels_ is the cache of grpc::Channel objects.
```cpp
typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;
```
The code:
```cpp
template <typename ChannelCacheT>
class GenericCachingChannelCache : public ChannelCacheT {
 public:
  explicit GenericCachingChannelCache(int num_channels_per_target)
      : num_channels_per_target_(
            num_channels_per_target > 0 ? num_channels_per_target : 1) {}

  ~GenericCachingChannelCache() override {}

  SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
    {
      mutex_lock l(mu_);
      auto iter = channels_.find(target);
      if (iter != channels_.end()) {
        return GetNextChannelPtrAndUpdateState(iter->second);
      }
    }

    ChannelState new_chan_state;
    for (int indx = 0; indx < num_channels_per_target_; indx++) {
      auto ch = FindChannelOnce(target);
      if (!ch) return nullptr;
      new_chan_state.channels.push_back(ch);
    }
    new_chan_state.last_used = num_channels_per_target_ - 1;

    {
      mutex_lock l(mu_);
      typename absl::flat_hash_map<string, ChannelState>::iterator iter;
      bool was_inserted;
      std::tie(iter, was_inserted) = channels_.insert({target, new_chan_state});
      return GetNextChannelPtrAndUpdateState(iter->second);
    }
  }

 protected:
  // Find the ClientChannel for "target". Only called when no channel was
  // found in the channels_ cache for "target". A non nullptr result will be
  // cached in channels_.
  virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;

 private:
  struct ChannelState {
    std::vector<SharedGrpcChannelPtr> channels;
    int last_used;
  };

  // Should be called with mu_ held.
  SharedGrpcChannelPtr GetNextChannelPtrAndUpdateState(
      ChannelState& chan_state) {
    // Following statement is marked as Crash OK as this is an invariant of
    // code flow in this class.
    CHECK_EQ(chan_state.channels.size(), num_channels_per_target_);  // Crash OK

    chan_state.last_used =
        (chan_state.last_used + 1) % num_channels_per_target_;
    return chan_state.channels[chan_state.last_used];
  }

  const int num_channels_per_target_;
  // TODO(zhifengc): Eviction when the map becomes too big.
  mutex mu_;
  absl::flat_hash_map<string, ChannelState> channels_ TF_GUARDED_BY(mu_);
};
```
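To make the round-robin behavior concrete, here is a self-contained miniature of the same scheme. This is a sketch only: MiniChannelCache and FakeChannelPtr are made-up names, std::string stands in for grpc::Channel, and the locking the real class does under mu_ is omitted.

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

using FakeChannelPtr = std::shared_ptr<std::string>;

class MiniChannelCache {
 public:
  explicit MiniChannelCache(int num_channels_per_target)
      : num_channels_per_target_(num_channels_per_target) {}

  FakeChannelPtr FindWorkerChannel(const std::string& target) {
    auto iter = channels_.find(target);
    if (iter == channels_.end()) {
      // Cache miss: create all channels for this target in one go,
      // mirroring the FindChannelOnce loop above.
      ChannelState state;
      for (int i = 0; i < num_channels_per_target_; ++i) {
        state.channels.push_back(std::make_shared<std::string>(
            target + "#chan" + std::to_string(i)));
      }
      state.last_used = num_channels_per_target_ - 1;
      iter = channels_.emplace(target, std::move(state)).first;
    }
    // Pick the next channel of this target round-robin.
    ChannelState& s = iter->second;
    s.last_used = (s.last_used + 1) % num_channels_per_target_;
    return s.channels[s.last_used];
  }

 private:
  struct ChannelState {
    std::vector<FakeChannelPtr> channels;
    int last_used;
  };
  const int num_channels_per_target_;
  std::map<std::string, ChannelState> channels_;
};

int main() {
  MiniChannelCache cache(/*num_channels_per_target=*/2);
  // Same target, alternating channels: #chan0, #chan1, #chan0, #chan1.
  for (int i = 0; i < 4; ++i) {
    std::cout << *cache.FindWorkerChannel("/job:ps/replica:0/task:0") << "\n";
  }
}
```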
2.3 Concrete Derived Classes
Two classes derive from CachingGrpcChannelCache:
2.3.1 Leaf Nodes
SparseGrpcChannelCache is the leaf node. Each job in the cluster corresponds to one SparseGrpcChannelCache, and the channel collection inside a SparseGrpcChannelCache is the channel collection for that job's tasks: one grpc::Channel per task.
The main member variables of SparseGrpcChannelCache are:
- const string job_id_: which job this cache serves.
- const std::map<int, string> host_ports_: the host:port addresses of this job's tasks.
- const ChannelCreationFunction channel_func_: the function used to create a grpc::Channel.
Its main methods are:
- ListWorkers: returns the task names of this job.
- TranslateTask: maps a task name to its address (host:port); for example, /job:ps/replica:0/task:1 might resolve to ps1:1111 (note that, per the code below, only replica 0 is accepted).
- FindChannelOnce: creates the grpc::Channel for a given task name. Concretely, it first resolves the task's address via TranslateTask, then builds the channel from that address.
```cpp
class SparseGrpcChannelCache : public CachingGrpcChannelCache {
 public:
  SparseGrpcChannelCache(const string& job_id,
                         const std::map<int, string>& host_ports,
                         ChannelCreationFunction channel_func,
                         int num_channels_per_target)
      : CachingGrpcChannelCache(num_channels_per_target),
        job_id_(job_id),
        host_ports_(host_ports),
        channel_func_(std::move(channel_func)) {}

  ~SparseGrpcChannelCache() override {}

  void ListWorkers(std::vector<string>* workers) override {
    workers->reserve(workers->size() + host_ports_.size());
    for (const auto& id_host_port : host_ports_) {
      workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
    }
  }

  void ListWorkersInJob(const string& job_name,
                        std::vector<string>* workers) override {
    if (job_name == job_id_) {
      ListWorkers(workers);
    }
  }

  string TranslateTask(const string& target) override {
    DeviceNameUtils::ParsedName parsed;
    if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
      return "";
    }
    if (!parsed.has_job || parsed.job != job_id_) {
      return "";
    }
    if (!parsed.has_replica || parsed.replica != 0) {
      return "";
    }
    int32_t task = parsed.has_task ? parsed.task : -1;
    auto iter = host_ports_.find(task);
    if (iter == host_ports_.end()) {
      return "";
    }
    return iter->second;
  }

 protected:
  SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
    const string host_port = TranslateTask(target);
    if (host_port.empty()) {
      return nullptr;
    }
    auto chan_ptr = channel_func_(host_port);
    return chan_ptr;
  }

 private:
  const string job_id_;
  const std::map<int, string> host_ports_;
  const ChannelCreationFunction channel_func_;

  TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
};
```
2.3.2 Non-Leaf Nodes
To speed up lookups across SparseGrpcChannelCaches and to manage all worker nodes of the cluster as one unit, TF combines the SparseGrpcChannelCaches of the cluster into a MultiGrpcChannelCache. MultiGrpcChannelCache remembers which SparseGrpcChannelCache served each target it has already seen.
```cpp
// A ChannelCache that is the union of multiple ChannelCaches.
// Takes ownership of the caches passed to the constructor.
class MultiGrpcChannelCache : public CachingGrpcChannelCache {
 public:
  explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches,
                                 int num_channels_per_target)
      : CachingGrpcChannelCache(num_channels_per_target), caches_(caches) {}

  ~MultiGrpcChannelCache() override {
    for (GrpcChannelCache* cache : caches_) {
      delete cache;
    }
  }

  void ListWorkers(std::vector<string>* workers) override {
    for (GrpcChannelCache* cache : caches_) {
      cache->ListWorkers(workers);
    }
  }

  void ListWorkersInJob(const string& job_name,
                        std::vector<string>* workers) override {
    for (GrpcChannelCache* cache : caches_) {
      cache->ListWorkersInJob(job_name, workers);
    }
  }

  string TranslateTask(const string& target) override {
    mutex_lock l(mu_);  // could use reader lock
    GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
    if (cache == nullptr) {
      for (GrpcChannelCache* c : caches_) {
        string r = c->TranslateTask(target);
        if (!r.empty()) {
          target_caches_.insert({target, c});
          cache = c;
          break;
        }
      }
    }
    return cache->TranslateTask(target);
  }

 protected:
  SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
    for (GrpcChannelCache* cache : caches_) {
      SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
      if (ch) {
        mutex_lock l(mu_);
        target_caches_.insert({target, cache});
        return ch;
      }
    }
    return nullptr;
  }

 private:
  // List of channels used by this MultiGrpcChannelCache.
  const std::vector<GrpcChannelCache*> caches_;

  mutex mu_;
  // Cache of channels keyed by the target they are handling.
  // The same GrpcChannelCache can appear multiple times in the cache.
  std::unordered_map<string, GrpcChannelCache*> target_caches_
      TF_GUARDED_BY(mu_);
};
```
The structure so far:
Figure 2: Cache relationships
2.4 Creating the GrpcChannelCache
When the GrpcChannelCache was created earlier, GetChannelCreationFunction was passed in without explanation; let's work through it now.
```cpp
// Get the GrpcChannelCache.
std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
    channel_spec, GetChannelCreationFunction(), *options.rpc_options));
```
2.4.1 Goal and Usage
Let's first look at how it is used, i.e. the goal: given a target (a host:port string), produce a SharedGrpcChannelPtr; as we saw, SharedGrpcChannelPtr is a shared pointer to grpc::Channel.
```cpp
SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
  const string host_port = TranslateTask(target);
  if (host_port.empty()) {
    return nullptr;
  }
  auto chan_ptr = channel_func_(host_port);
  VLOG(5) << "Channel created for: job: " << job_id_
          << " host_port: " << host_port << " target : " << target
          << " Ptr: " << chan_ptr.get();
  return chan_ptr;
}
```
2.4.2 NewHostPortGrpcChannel
First, NewHostPortGrpcChannel, an existing TF API. It calls ::grpc::CreateCustomChannel (a gRPC API) to obtain a grpc channel, stores it into the out-parameter SharedGrpcChannelPtr* channel_pointer (i.e. the grpc channel), and returns a Status. The result is exactly what we want, but the calling convention does not match, so it has to be wrapped and adapted.
```cpp
Status NewHostPortGrpcChannel(const string& target,
                              const RPCOptions* rpc_options,
                              SharedGrpcChannelPtr* channel_pointer) {
  // Minimally ensure that the target is valid
  TF_RETURN_IF_ERROR(ValidateHostPortPair(target));

  ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
  *channel_pointer = ::grpc::CreateCustomChannel(
      "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
  return Status::OK();
}
```
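A minimal usage sketch of this API as declared above (inside the tensorflow namespace; the address is illustrative):

```cpp
SharedGrpcChannelPtr channel;
Status s = NewHostPortGrpcChannel("localhost:2222",
                                  /*rpc_options=*/nullptr, &channel);
if (s.ok()) {
  // channel now wraps a ready-to-use ::grpc::Channel.
}
```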
2.4.3 ConvertToChannelCreationFunction
ConvertToChannelCreationFunction adapts the new_channel_func_ptr function passed to it, turning it into a function that produces a SharedGrpcChannelPtr from just a const string& target.
```cpp
ChannelCreationFunction ConvertToChannelCreationFunction(
    const std::function<Status(string, const RPCOptions*,
                               SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
  return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
    SharedGrpcChannelPtr channel_ptr;
    if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
            .ok()) {
      return channel_ptr;
    } else {
      return nullptr;
    }
  };
}
```
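Concretely, the conversion can be exercised like this (a sketch; the address is illustrative, and this is equivalent to what GetChannelCreationFunction returns in the next subsection):

```cpp
ChannelCreationFunction func =
    ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
SharedGrpcChannelPtr ch = func("localhost:2222");  // nullptr on failure
```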
2.4.4 GetChannelCreationFunction
GetChannelCreationFunction simply feeds NewHostPortGrpcChannel into ConvertToChannelCreationFunction; the resulting function has the shape that the WorkerCache factory can actually use.
```cpp
ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
  // We can do this because SparseGrpcChannelCache is robust to nullptr being
  // returned by the channel creation function
  return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
}
```
2.4.5 Usage Analysis
Back to our call site: channel_func_ is the function returned by GetChannelCreationFunction, so calling it directly yields the grpc::Channel.
```cpp
SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
  const string host_port = TranslateTask(target);
  auto chan_ptr = channel_func_(host_port);
  // ...
}
```
We can now extend the earlier picture: one step has been added in the middle, so that passing in a target yields a grpc::Channel.
Figure 3: The conversion
3. Where the Caches Sit in the System
We have summarized how the caches are initialized and used, but lost track of where they live in the system, so let's look at that now. The GrpcChannelCache inside GrpcWorkerCache points at the system's gRPC channel cache and is used to fetch cached gRPC channels; local_worker stores the local Worker.
Figure 4: Where the caches sit
When GrpcWorkerCache::GetOrCreateWorker is called and the target is local, it returns local_worker directly (the local GrpcWorker we configured earlier); otherwise it creates a remote GrpcRemoteWorker on top of the worker's RPC channel.
Figure 5: Creating workers
WorkerCacheInterface (i.e. GrpcWorkerCache) can be seen all over Master, Worker, MasterSession, and WorkerSession; many classes hold a member pointer to it, so it is used very widely.
4. Finding the Device Set
To create WorkerSessions, the MasterSession needs to know the device sets on all remote workers, so before creating the MasterSession the Master iterates over all workers and collects their device information. Since this relies on GrpcWorkerCache, we cover it here. The basic flow is:
- Use GrpcWorkerCache::ListWorkers to get the names of all workers in the cluster.
- For each worker_name, call GetOrCreateWorker to look up the WorkerInterface object inside worker_cache, creating it if absent.
- Build a GetStatusRequest and send it to the worker found, via GetStatusAsync.
- When the worker returns a GetStatusResponse, the function object inside the callback cb (the WhenFound method) collects the worker's device information; the device information is post-processed to incorporate the worker_name.
Figure 6: Fetching devices
4.1 DeviceFinder
4.1.1 Definition
DeviceFinder is a helper class implementing the algorithm that discovers remote workers' devices. Its member variables first:
```cpp
class DeviceFinder {
  ~DeviceFinder() {
    for (Device* dev : found_) delete dev;
  }

  typedef DeviceFinder ME;
  const MasterEnv* env_;
  WorkerCacheInterface* worker_cache_;
  std::vector<DeviceNameUtils::ParsedName> filters_;

  mutex mu_;
  int num_pending_ TF_GUARDED_BY(mu_);
  condition_variable pending_zero_;
  std::vector<Device*> found_ TF_GUARDED_BY(mu_);
  // List of targets to be contacted by this DeviceFinder. The
  // respective `bool` in `seen_targets_` indicates whether we have
  // heard from this target or not.
  std::vector<string> targets_;
  std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_);
  Status status_;

  TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
};
```
4.1.2 Initialization
The main logic: obtain the list of names of all workers in the cluster via GrpcWorkerCache::ListWorkers (narrowed by device filters if any are specified).
```cpp
explicit DeviceFinder(
    const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
    WorkerCacheInterface* worker_cache)
    : env_(env), worker_cache_(worker_cache) {
  CHECK(worker_cache) << "Worker cache was null!";
  auto process_filter = [this](const string& filter) {
    DeviceNameUtils::ParsedName parsed;
    if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
      filters_.push_back(parsed);
    } else {
      LOG(FATAL) << "Skipping invalid filter: " << filter;
    }
  };
  for (const string& filter : device_filters) {
    process_filter(filter);
  }
  // Enumerates all known workers' target. A target name is a
  // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
  if (filters_.empty()) {
    // If no filters were specified, we list all known workers in
    // `worker_cache`.
    std::vector<string> workers;
    worker_cache->ListWorkers(&workers);
    std::swap(workers, targets_);
  } else {
    // When applying filters, we must include the local worker, even if it
    // does not match any of the filters.
    CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
    const string& local_device_name = env_->local_devices[0]->name();
    DeviceNameUtils::ParsedName local_parsed_name;
    CHECK(DeviceNameUtils::ParseFullName(local_device_name,
                                         &local_parsed_name));
    bool all_filters_have_job = true;
    std::unordered_set<string> filter_job_names({local_parsed_name.job});
    for (const DeviceNameUtils::ParsedName& filter : filters_) {
      all_filters_have_job = all_filters_have_job && filter.has_job;
      if (filter.has_job) {
        filter_job_names.insert(filter.job);
      }
    }

    std::vector<string> workers;
    if (all_filters_have_job) {
      // If all of the device filters have a job specified, then we only need
      // to list the workers in the jobs named in the filter, because a worker
      // in any other job would not match any filter.
      for (const string& job_name : filter_job_names) {
        VLOG(2) << "Selectively listing workers in job: " << job_name;
        std::vector<string> workers_in_job;
        worker_cache->ListWorkersInJob(job_name, &workers_in_job);
        workers.insert(workers.end(), workers_in_job.begin(),
                       workers_in_job.end());
      }
    } else {
      // If any of the device filters does not have a job specified, then we
      // must list the workers from all jobs.
      VLOG(2) << "Listing workers in all jobs because some device "
              << "filter has no job specified. Filters were:";
      if (device_filters.empty()) {
        VLOG(2) << "- <NO FILTERS>";
      } else {
        for (const string& filter : device_filters) {
          VLOG(2) << "- " << filter;
        }
      }
      worker_cache->ListWorkers(&workers);
    }
    for (const string& name : workers) {
      if (MatchFilters(name) ||
          DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
        targets_.push_back(name);
      }
    }
  }
  seen_targets_.assign(targets_.size(), false);
}
```
4.1.3 GetRemoteDevices
GetRemoteDevices fetches the remote devices as follows:
- finder.Start() broadcasts a GetStatusRequest to every worker in the cluster.
- finder.Wait() collects the GetStatusResponse messages returned by all workers.
- finder.GetRemoteDevices fetches the query results and returns them to the caller.
```cpp
static Status GetRemoteDevices(
    const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
    WorkerCacheInterface* worker_cache,
    std::vector<std::unique_ptr<Device>>* out_remote) {
  DeviceFinder finder(device_filters, env, worker_cache);
  finder.Start();
  TF_RETURN_IF_ERROR(finder.Wait());
  finder.GetRemoteDevices(env->local_devices, out_remote);
  return Status::OK();
}
```
4.1.3.1 Start
Start initializes the counter num_pending_ to the number of workers, then iterates over the workers, calling NewRemoteDevices on each one.
```cpp
void Start() {
  {
    mutex_lock l(mu_);
    num_pending_ = targets_.size();
    if (num_pending_ == 0) {
      pending_zero_.notify_all();
    }
  }
  // Talk to all workers to get the list of available devices.
  using std::placeholders::_1;
  using std::placeholders::_2;
  for (size_t i = 0; i < targets_.size(); ++i) {
    // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
    // never be called.
    NewRemoteDevices(env_->env, worker_cache_, targets_[i],
                     std::bind(&ME::WhenFound, this, i, _1, _2));
  }
}
```
The logic of NewRemoteDevices is:
- For the given worker_name, call GetOrCreateWorker to look up the WorkerInterface object inside worker_cache, creating it if absent.
- Build a GetStatusRequest and send it to the worker found, via GetStatusAsync.
- When the worker returns a GetStatusResponse, the function object inside the callback cb (the WhenFound method) collects the worker's device information, post-processing it to incorporate the worker_name.
```cpp
void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
                      const string& worker_name, NewRemoteDevicesDone done) {
  WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
  if (wi == nullptr) {
    std::vector<Device*> empty;
    done(errors::NotFound("Device ", worker_name, " is not found."), &empty);
    return;
  }
  struct Call {
    GetStatusRequest req;    // Request message.
    GetStatusResponse resp;  // Response message.
  };
  Call* call = new Call;
  // The callback.
  auto cb = [env, worker_cache, worker_name, done, wi,
             call](const Status& status) {
    Status s = status;
    std::vector<Device*> remote_devices;
    auto cleanup = gtl::MakeCleanup(
        [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
          worker_cache->ReleaseWorker(worker_name, wi);
          done(s, &remote_devices);
          delete call;
        });
    if (s.ok()) {
      DeviceNameUtils::ParsedName worker_name_parsed;
      if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
          !worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
          !worker_name_parsed.has_task) {
        s = errors::InvalidArgument("Could not parse worker name: ",
                                    worker_name);
        return;
      }
      remote_devices.reserve(call->resp.device_attributes_size());
      for (const DeviceAttributes& da : call->resp.device_attributes()) {
        DeviceNameUtils::ParsedName device_name_parsed;
        CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
            << "Device attribute name '" << da.name() << "' could not be "
            << "parsed. Device Attribute: " << da.DebugString();
        // Preserve the exact name, if possible.
        if (device_name_parsed.job == worker_name_parsed.job &&
            device_name_parsed.replica == worker_name_parsed.replica &&
            device_name_parsed.task == worker_name_parsed.task) {
          auto d = new RemoteDevice(env, da);
          remote_devices.push_back(d);
        } else {
          DeviceAttributes da_rewritten = da;
          da_rewritten.set_name(DeviceNameUtils::FullName(
              worker_name_parsed.job, worker_name_parsed.replica,
              worker_name_parsed.task, device_name_parsed.type,
              device_name_parsed.id));
          auto d = new RemoteDevice(env, da_rewritten);
          // Experimental: Skipping over adding any TPU-type devices that
          // aren't on the job called "worker" (but still adds the CPUs of
          // other jobs).
          if (getenv("TPU_NO_POPULATE_DEVICE_LIST_FROM_CLUSTER_SPEC") !=
              nullptr) {
            if (worker_name_parsed.job == "worker" ||
                device_name_parsed.type.find("TPU") == std::string::npos) {
              remote_devices.push_back(d);
            }
          } else {
            remote_devices.push_back(d);
          }
        }
      }
    }
  };
  wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
                     /*fail_fast=*/false, cb);
}
```
4.1.3.2 Wait
In Wait, while the counter is nonzero, the method keeps calling pending_zero_.wait_for, with the main thread sleeping 10 seconds at a time in between and logging the workers that have not responded yet.
```cpp
Status Wait() {
  mutex_lock l(mu_);
  // TODO(mrry): Propagate a timeout here, since `num_pending_` may
  // never become zero.
  while (num_pending_ != 0) {
    pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
    if (num_pending_ != 0) {
      for (size_t i = 0; i < targets_.size(); ++i) {
        if (!seen_targets_[i]) {
          LOG(INFO) << "CreateSession still waiting for response from worker: "
                    << targets_[i];
        }
      }
    }
  }
  return status_;
}
```
4.1.3.3 The Callback
Start's callback is below; it runs whenever some worker's GetStatusResponse arrives. WhenFound decrements the counter by one; when the counter reaches zero, it calls pending_zero_.notify_all(), which wakes the pending_zero_.wait_for in Wait, and GetRemoteDevices then uses finder.GetRemoteDevices to fetch the query results and return them to the caller.
```cpp
void WhenFound(int target_index, const Status& s,
               std::vector<Device*>* devices) {
  mutex_lock l(mu_);
  seen_targets_[target_index] = true;
  if (!s.ok()) {
    LOG(ERROR) << "CreateSession failed because worker "
               << targets_[target_index] << " returned error: " << s;
    status_.Update(s);
  } else {
    found_.insert(found_.end(), devices->begin(), devices->end());
    devices->clear();
  }
  --num_pending_;
  if (num_pending_ == 0) {
    pending_zero_.notify_all();
  }
}
```
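The Start/Wait/WhenFound trio is a classic scatter-gather barrier: fan out N async requests, count down under a mutex, and wake the waiter when the counter hits zero. Here is a self-contained sketch of the same pattern (std::thread stands in for the async RPC; MiniFinder and all names are illustrative):

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class MiniFinder {
 public:
  explicit MiniFinder(int num_targets) : num_pending_(num_targets) {}

  // Scatter: fire one async "request" per target (a thread stands in
  // for the async RPC that NewRemoteDevices would issue).
  void Start() {
    const int n = num_pending_;
    for (int i = 0; i < n; ++i) {
      threads_.emplace_back([this, i] { WhenFound(i); });
    }
  }

  // Gather: block until every target has responded.
  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    pending_zero_.wait(l, [this] { return num_pending_ == 0; });
    l.unlock();
    for (auto& t : threads_) t.join();
  }

 private:
  // The per-target callback: record the response and count down.
  void WhenFound(int target_index) {
    std::lock_guard<std::mutex> l(mu_);
    std::cout << "target " << target_index << " responded\n";
    if (--num_pending_ == 0) pending_zero_.notify_all();
  }

  std::mutex mu_;
  std::condition_variable pending_zero_;
  int num_pending_;
  std::vector<std::thread> threads_;
};

int main() {
  MiniFinder finder(3);
  finder.Start();
  finder.Wait();  // returns once all three "targets" have responded
}
```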
4.2 Interacting with the Worker
Inside NewRemoteDevices, GetStatusAsync is used to send the constructed GetStatusRequest to the worker that was found.
```cpp
WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);

wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
                   /*fail_fast=*/false, cb);
```
4.2.1 GrpcRemoteWorker
wi is the WorkerInterface that was found; it is in fact a GrpcRemoteWorker, the gRPC client that invokes the corresponding services of the remote WorkerService through its stub.
```cpp
void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                    GetStatusResponse* response, bool fail_fast,
                    StatusCallback done) override {
  IssueRequest(request, response, getstatus_, std::move(done), call_opts,
               fail_fast);
}
```
4.2.2 GrpcWorkerService
On the remote worker, the message is received by GrpcWorkerService. A GetStatusRequest is handled by the GetStatusHandler callback, which is generated by a macro.
```cpp
#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

HANDLE_CALL(GetStatus, false);
```
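Expanding HANDLE_CALL(GetStatus, false) by hand shows what the generated handler looks like (a mechanical expansion of the macro above, with the dead branch folded away):

```cpp
void GetStatusHandler(WorkerCall<GetStatusRequest, GetStatusResponse>* call) {
  auto closure = [this, call]() {
    Status s = worker_->GetStatus(&call->request, &call->response);
    if (!s.ok()) {
      VLOG(3) << "Bad response from GetStatus: " << s;
    }
    call->SendResponse(ToGrpcStatus(s));
  };
  // may_block_on_compute_pool == false, so the closure runs on the
  // compute thread pool rather than on its own thread.
  worker_->env()->compute_pool->Schedule(std::move(closure));
  // Re-arm the request queue so the next GetStatusRequest can be received.
  ENQUEUE_REQUEST(GetStatus, false);
}
```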
4.2.3 Worker
Finally we arrive at the Worker class, which really just delegates to DeviceMgr and eventually returns the result to the remote caller in a GetStatusResponse.
```cpp
void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) {
  const DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(Status::OK());
}
```
4.2.4 DeviceMgr
ListDeviceAttributes, which aggregates local device information, has two implementations. The first:
```cpp
void StaticDeviceMgr::ListDeviceAttributes(
    std::vector<DeviceAttributes>* devices) const {
  devices->reserve(devices_.size());
  for (const auto& dev : devices_) {
    devices->emplace_back(dev->attributes());
  }
}
```
The second implementation:
```cpp
void DynamicDeviceMgr::ListDeviceAttributes(
    std::vector<DeviceAttributes>* devices) const {
  tf_shared_lock l(devices_mu_);
  devices->reserve(dynamic_devices_.size());
  for (const auto& d : dynamic_devices_) {
    devices->emplace_back(d->attributes());
  }
}
```
This concludes our analysis of the caches and of device-set discovery; next we will look at how the business logic is processed.