Distributed TensorFlow allows us to share parts of a TensorFlow graph between multiple processes, possibly each on a different machine.

Why might we want to do this? The classic use case is to harness the power of multiple machines for training, with shared parameters between all machines. Even if we're just running on a single machine, though, I've come across two examples from reinforcement learning where sharing between processes has been necessary:

  • In A3C, multiple agents run in parallel in multiple processes, exploring different copies of the environment at the same time. Each agent generates parameter updates, which must also be sent to the other agents in different processes.
  • In Deep Reinforcement Learning from Human Preferences, one process runs an agent exploring the environment, with rewards calculated by a network trained from human preferences about agent behaviour. This network is trained asynchronously from the agent, in a separate process, so that the agent doesn't have to wait for each training cycle to complete before it can continue exploring.

Unfortunately, the official documentation on Distributed TensorFlow rather jumps in at the deep end. For a slightly gentler introduction, let's run through some really basic examples with Jupyter.

If you'd like to follow along with this notebook interactively, the source can be found on GitHub.

(Note: some of the explanations here are my own interpretation of empirical results/TensorFlow documentation. If you see anything that's wrong, give me a shout!)

Introduction

import tensorflow as tf

Let's say we want multiple processes to have shared access to some common parameters. For simplicity, suppose this is just a single variable:

var = tf.Variable(initial_value=0.0)

As a first step, we can imagine that each process would need its own session. (Pretend session 1 is created in one process, and session 2 in another.)

sess1 = tf.Session()
sess2 = tf.Session()

sess1.run(tf.global_variables_initializer())
sess2.run(tf.global_variables_initializer())

Each call to tf.Session() creates a separate execution engine, then connects the session handle to the execution engine. The execution engine is what actually stores variable values and runs operations.

Normally, execution engines in different processes are unlinked. Changing var in one session (on one execution engine) won't affect var in the other session.

print("Initial value of var in session 1:", sess1.run(var))
print("Initial value of var in session 2:", sess2.run(var))

sess1.run(var.assign_add(1.0))
print("Incremented var in session 1")

print("Value of var in session 1:", sess1.run(var))
print("Value of var in session 2:", sess2.run(var))
Initial value of var in session 1: 0.0
Initial value of var in session 2: 0.0
Incremented var in session 1
Value of var in session 1: 1.0
Value of var in session 2: 0.0

Distributed TensorFlow

In order to share variables between processes, we need to link the different execution engines together. Enter Distributed TensorFlow.

With Distributed TensorFlow, each process runs a special execution engine: a TensorFlow server. Servers are linked together as part of a cluster. (Each server in the cluster is also known as a task.)

The first step is to define what the cluster looks like. We start off with the simplest possible cluster: two servers (two tasks), both on the same machine; one that will listen on port 2222, one on port 2223.

tasks = ["localhost:2222", "localhost:2223"]

Each task is associated with a job, which is a collection of related tasks. We associate both tasks with a job called "local".

jobs = {"local": tasks}

This completes the definition of the cluster.

cluster = tf.train.ClusterSpec(jobs)
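
As an aside, a cluster spanning multiple machines is defined in exactly the same way: the only difference is that the tasks use real hostnames rather than localhost. (The hostnames below are made up, purely for illustration.)

# Hypothetical multi-machine cluster; the hostnames are placeholders
multi_machine_cluster = tf.train.ClusterSpec({
    "local": ["machine-a.example.com:2222", "machine-b.example.com:2222"]
})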

We can now launch the servers, specifying which server in the cluster definition each server corresponds to. Each server starts immediately, listening on the port specified in the cluster definition.

# "This server corresponds to the the first task (task_index=0)
# of the tasks associated with the 'local' job."
server1 = tf.train.Server(cluster, job_name="local", task_index=0)

server2 = tf.train.Server(cluster, job_name="local", task_index=1)

With the servers linked together in the same cluster, we can now experience the main magic of Distributed TensorFlow: any variable with the same name will be shared between all servers.

The simplest example is to run the same graph on all servers, each graph with just one variable, as before:

tf.reset_default_graph()
var = tf.Variable(initial_value=0.0, name='var')
sess1 = tf.Session(server1.target)
sess2 = tf.Session(server2.target)

Modifications made to the variable on one server will now be mirrored on the second server.

sess1.run(tf.global_variables_initializer())
sess2.run(tf.global_variables_initializer())

print("Initial value of var in session 1:", sess1.run(var))
print("Initial value of var in session 2:", sess2.run(var))

sess1.run(var.assign_add(1.0))
print("Incremented var in session 1")

print("Value of var in session 1:", sess1.run(var))
print("Value of var in session 2:", sess2.run(var))
Initial value of var in session 1: 0.0
Initial value of var in session 2: 0.0
Incremented var in session 1
Value of var in session 1: 1.0
Value of var in session 2: 1.0

(Note that because we only have one variable, and that variable is shared between both sessions, the second run of global_variables_initializer here is redundant.)

Placement

A question that might be in our minds at this point is: which server does the variable actually get stored on? And for operations, which server actually runs them?

Empirically, it seems that by default, variables and operations get placed on the first task in the cluster.

def run_with_location_trace(sess, op):
    # From https://stackoverflow.com/a/41525764/7832197
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    sess.run(op, options=run_options, run_metadata=run_metadata)
    for device in run_metadata.step_stats.dev_stats:
      print(device.device)
      for node in device.node_stats:
        print("  ", node.node_name)

For example, if we do something to var using the session connected to the first task, everything happens on that task:

run_with_location_trace(sess1, var)
/job:local/replica:0/task:0/device:CPU:0
   _SOURCE
   var
run_with_location_trace(sess1, var.assign_add(1.0))
/job:local/replica:0/task:0/device:CPU:0
   _SOURCE
   AssignAdd_1/value
   var
   AssignAdd_1

But if we try to do something to var using the session connected to the second task, the graph nodes still get run on the first task.

run_with_location_trace(sess2, var)
/job:local/replica:0/task:1/device:CPU:0
   _SOURCE
/job:local/replica:0/task:0/device:CPU:0
   _SOURCE
   var

To fix a variable or an operation to a specific task, we can use tf.device:

with tf.device("/job:local/task:0"):
    var1 = tf.Variable(0.0, name='var1')
with tf.device("/job:local/task:1"):
    var2 = tf.Variable(0.0, name='var2')
    
# (This will initialize both variables)
sess1.run(tf.global_variables_initializer())

Now var1 runs on the first task, as before.

run_with_location_trace(sess1, var1)
/job:local/replica:0/task:0/device:CPU:0
   _SOURCE
   var1

But var2 runs on the second task. Even if we try to evaluate it using the session connected to the first task, it still runs on the second task.

run_with_location_trace(sess1, var2)
/job:local/replica:0/task:0/device:CPU:0
   _SOURCE
/job:local/replica:0/task:1/device:CPU:0
   _SOURCE
   var2

And vice versa using the session connected to the second task: var2 runs locally on the second task, while var1 still runs on the first.

run_with_location_trace(sess2, var2)
/job:local/replica:0/task:1/device:CPU:0
   _SOURCE
   var2
run_with_location_trace(sess2, var1)
/job:local/replica:0/task:1/device:CPU:0
   _SOURCE
/job:local/replica:0/task:0/device:CPU:0
   _SOURCE
   var1
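
The traces above also show the full form of a device string: /job:<job_name>/replica:<replica>/task:<task_index>/device:<device_type>:<device_index>. If we wanted to pin something to a particular device within a task, rather than just to the task, we could spell out more of the string. (This is only a sketch; the GPU below is an assumption, and only makes sense if task 1's machine actually has one.)

# Sketch only: pin a variable to a specific device within task 1.
# Assumes task 1's machine actually has a GPU; otherwise use CPU:0.
with tf.device("/job:local/task:1/device:GPU:0"):
    var3 = tf.Variable(0.0, name='var3')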

Graphs

There are a couple of things to note about how graphs work with Distributed TensorFlow.

Who builds the graph?

First, although variable values are shared throughout the cluster, the graph is not automatically shared.

Let's create a fresh cluster with two servers, and set up the first server with an explicitly-created graph.

cluster = tf.train.ClusterSpec({"local": ["localhost:2224", "localhost:2225"]})
server1 = tf.train.Server(cluster, job_name="local", task_index=0)
server2 = tf.train.Server(cluster, job_name="local", task_index=1)
graph1 = tf.Graph()
with graph1.as_default():
    var1 = tf.Variable(0.0, name='var')
sess1 = tf.Session(target=server1.target, graph=graph1)
print(graph1.get_operations())
[<tf.Operation 'var/initial_value' type=Const>, <tf.Operation 'var' type=VariableV2>, <tf.Operation 'var/Assign' type=Assign>, <tf.Operation 'var/read' type=Identity>]

If we then create a session connected to the second server, note that the graph does not automatically get mirrored.

graph2 = tf.Graph()
sess2 = tf.Session(target=server2.target, graph=graph2)
print(graph2.get_operations())
[]

To access the shared variable, we must manually add a variable with the same name to the second graph.

with graph2.as_default():
    var2 = tf.Variable(0.0, name='var')

Only then can we access it.

sess1.run(var1.assign(1.0))
sess2.run(var2)
1.0

The takeaway is: each server is responsible for building its own graph.

Does the graph have to be the same on all servers?

So far, all our examples have run the same graph structure on both servers. This is known as in-graph replication.

For example, let's say we have a cluster containing three servers. Server 1 holds shared parameters, while servers 2 and 3 are worker nodes, each with local variables. With in-graph replication, every server's graph would contain the shared parameters together with the local variables and operations of both workers.

The issue with in-graph replication is that every server has to have a copy of the entire graph, including the parts of the graph that might only be relevant for other servers. This can lead to graphs growing very large.

The alternative is between-graph replication. Here, each server runs a graph containing only the shared parameters, and whatever variables and operations are relevant to that individual server.

Because it keeps graph sizes smaller, between-graph replication is the recommended approach.
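
To make the pattern concrete, here's a rough sketch of what each worker's graph-building code might look like under between-graph replication. (This is only an illustration; the runnable version of this pattern is the full example at the end of the post, and the "ps"/"worker" job names are the ones used there.)

# Sketch of between-graph replication: each worker process builds a graph
# containing only the shared parameters (pinned to the parameter server's
# task) plus that worker's own operations - the other workers' operations
# never appear in this graph.
def build_worker_graph():
    with tf.device("/job:ps/task:0"):
        shared_var = tf.Variable(0.0, name='var')
    # This worker's own operation; other workers never see it
    increment_op = shared_var.assign_add(1.0)
    return shared_var, increment_op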

Practical Details

Before we finish with a full example, there are a few practical details to discuss.

What happens if we try to run something on the cluster before all servers have connected?

Let's create another two-task cluster.

cluster = tf.train.ClusterSpec({
    "local": ["localhost:2226", "localhost:2227"]
})

This time, let's start each server in a separate process. (This allows us to kill the servers, so that we can start them again for later experiments. There's currently no way of killing servers other than killing the process which started them.)

from multiprocessing import Process
from time import sleep

def s1():
    server1 = tf.train.Server(cluster,
                              job_name="local",
                              task_index=0)
    sess1 = tf.Session(server1.target)
    print("server 1: running no-op...")
    sess1.run(tf.no_op())
    print("server 1: no-op run!")
    server1.join() # Block

def s2():
    for i in range(3):
        print("server 2: %d seconds left before connecting..."
              % (3 - i))
        sleep(1.0)
    server2 = tf.train.Server(cluster,
                              job_name="local",
                              task_index=1)
    print("server 2: connected!")
    server2.join() # Block

# daemon=True so that these processes will definitely be killed
# when the parent process restarts
p1 = Process(target=s1, daemon=True)
p2 = Process(target=s2, daemon=True)

Server 1 joins the cluster immediately, but server 2 waits a little while before connecting. The result is shown below.

p1.start()
p2.start()
server 2: 3 seconds left before connecting...
server 1: running no-op...
server 2: 2 seconds left before connecting...
server 2: 1 seconds left before connecting...
server 2: connected!
server 1: no-op run!

As can be seen, any attempt to run an operation on the cluster blocks until all servers have joined.

p1.terminate()
p2.terminate()

What happens if a server leaves the cluster?

Let's set up a cluster with two servers. Server 1 will just repeatedly try to run a no-op located on server 1. Server 2 will die after two seconds.

def s1():
    server1 = tf.train.Server(cluster,
                              job_name="local",
                              task_index=0)
    
    with tf.device("/job:local/task:0"):
        no_op = tf.no_op()
        
    sess1 = tf.Session(server1.target)
    for _ in range(6):
        print("Server 1: about to run no-op...", end="")
        sess1.run(no_op)
        print("success!")
        sleep(1.0)

def s2():
    server2 = tf.train.Server(cluster,
                              job_name="local",
                              task_index=1)
    sleep(2.0)
    print("Server 2 dieing...")
    
p1 = Process(target=s1, daemon=True)
p2 = Process(target=s2, daemon=True)

p1.start()
p2.start()
Server 1: about to run no-op...success!
Server 1: about to run no-op...success!
Server 2 dying...
Server 1: about to run no-op...success!
Server 1: about to run no-op...success!
Server 1: about to run no-op...success!
Server 1: about to run no-op...success!

In the short term, it seems there's no problem, as long as the operation we're trying to run isn't on the server that's left. (I haven't tested what happens long-term.) (Update: see comment from Yaroslav Bulatov about this below.)

If the operation is on the server that leaves...

def s1():
    server1 = tf.train.Server(cluster,
                              job_name="local",
                              task_index=0)
    
    # This time, we place the no-op on server 2,
    # which is going to leave
    with tf.device("/job:local/task:1"):
        no_op = tf.no_op()
        
    sess1 = tf.Session(server1.target)
    for _ in range(5):
        print("Server 1: about to run no-op...", end="")
        sess1.run(no_op)
        print("success!")
        sleep(1.0)
    
p1 = Process(target=s1, daemon=True)
p2 = Process(target=s2, daemon=True)

p1.start()
p2.start()
Server 1: about to run no-op...success!
Server 1: about to run no-op...success!
Server 2 dying...

...then the attempt to run the operation blocks.

p1.terminate()
p2.terminate()

What happens if the server then comes back?

p1 = Process(target=s1, daemon=True)
p2 = Process(target=s2, daemon=True)
p1.start()
p2.start()
sleep(3.0)
# At this point, server 1 is blocked, and server 2 is dead.
print("Restarting server 2...")
p2 = Process(target=s2, daemon=True)
p2.start()
Server 1: about to run no-op...success!
Server 1: about to run no-op...success!
Server 2 dying...
Restarting server 2...
Process Process-7:
Traceback (most recent call last):
  File "/Users/matthew/tensorflow/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
    return fn(*args)
  File "/Users/matthew/tensorflow/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
    status, run_metadata)
  File "/Users/matthew/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.AbortedError: Graph handle is not found: 0000000000000001
Server 1: about to run no-op...Server 2 dying...

We get a Graph handle is not found error.

So the takeaway is that Distributed TensorFlow isn't automatically resilient to server failures. (If you are interested in fault tolerance, check out the TensorFlow Dev Summit video on Distributed TensorFlow.) (Update: also, see comments from Yaroslav Bulatov below.)

Who's responsible for initializing shared variables?

One approach is just to have all workers run tf.global_variables_initializer().

But if we want to keep it clean and have only one server do initialization, we could run into a problem if other servers try to use the variables before they've been initialized. One option for avoiding this problem is to have the other workers wait until initialization has taken place using tf.report_uninitialized_variables.

def s1():
    server1 = tf.train.Server(cluster,
                              job_name="local",
                              task_index=0)
    var = tf.Variable(0.0, name='var')
    sess1 = tf.Session(server1.target)
    
    print("Server 1: waiting for connection...")
    sess1.run(tf.report_uninitialized_variables())
    while len(sess1.run(tf.report_uninitialized_variables())) > 0:
        print("Server 1: waiting for initialization...")
        sleep(1.0)
    print("Server 1: variables initialized!")

def s2():
    server2 = tf.train.Server(cluster,
                              job_name="local",
                              task_index=1)
    var = tf.Variable(0.0, name='var')
    sess2 = tf.Session(server2.target)
    
    for i in range(3):
        print("Server 2: waiting %d seconds before initializing..."
              % (3 - i))
        sleep(1.0)
    sess2.run(tf.global_variables_initializer())
    
p1 = Process(target=s1, daemon=True)
p2 = Process(target=s2, daemon=True)
p1.start()
p2.start()
Server 1: waiting for connection...
Server 2: waiting 3 seconds before initializing...
Server 1: waiting for initialization...
Server 2: waiting 2 seconds before initializing...
Server 1: waiting for initialization...
Server 2: waiting 1 seconds before initializing...
Server 1: waiting for initialization...
Server 1: variables initialized!
p1.terminate()
p2.terminate()

Example

Let's put everything we've learned into one final example using multiple processes.

We'll create:

  • One parameter server, which will store a single variable, var.
  • Two worker tasks. Each worker will increment var a few times.

We'll have the parameter server print out the value of var a few times too so we can really see it changing.

import tensorflow as tf
from multiprocessing import Process
from time import sleep

cluster = tf.train.ClusterSpec({
    "worker": [
        "localhost:3333",
        "localhost:3334",
    ],
    "ps": [
        "localhost:3335"
    ]
})

def parameter_server():
    with tf.device("/job:ps/task:0"):
        var = tf.Variable(0.0, name='var')

    server = tf.train.Server(cluster,
                             job_name="ps",
                             task_index=0)
    sess = tf.Session(target=server.target)
    
    print("Parameter server: waiting for cluster connection...")
    sess.run(tf.report_uninitialized_variables())
    print("Parameter server: cluster ready!")
    
    print("Parameter server: initializing variables...")
    sess.run(tf.global_variables_initializer())
    print("Parameter server: variables initialized")
    
    for i in range(5):
        val = sess.run(var)
        print("Parameter server: var has value %.1f" % val)
        sleep(1.0)

    print("Parameter server: blocking...")
    server.join()
    

def worker(worker_n):
    with tf.device("/job:ps/task:0"):
        var = tf.Variable(0.0, name='var')
        
    server = tf.train.Server(cluster,
                             job_name="worker",
                             task_index=worker_n)
    sess = tf.Session(target=server.target)
    
    print("Worker %d: waiting for cluster connection..." % worker_n)
    sess.run(tf.report_uninitialized_variables())
    print("Worker %d: cluster ready!" % worker_n)
    
    while len(sess.run(tf.report_uninitialized_variables())) > 0:
        print("Worker %d: waiting for variable initialization..." % worker_n)
        sleep(1.0)
    print("Worker %d: variables initialized" % worker_n)
    
    for i in range(5):
        print("Worker %d: incrementing var" % worker_n)
        sess.run(var.assign_add(1.0))
        sleep(1.0)
    
    print("Worker %d: blocking..." % worker_n)
    server.join()

ps_proc = Process(target=parameter_server, daemon=True)
w1_proc = Process(target=worker, args=(0, ), daemon=True)
w2_proc = Process(target=worker, args=(1, ), daemon=True)
ps_proc.start()
Parameter server: waiting for cluster connection...
Parameter server: cluster ready!
Parameter server: initializing variables...
Parameter server: variables initialized
Parameter server: var has value 0.0
Parameter server: var has value 2.0
Parameter server: var has value 4.0
Parameter server: var has value 5.0
Parameter server: var has value 7.0
Parameter server: blocking...
w1_proc.start()
Worker 0: waiting for cluster connection...
Worker 0: cluster ready!
Worker 0: waiting for variable initialization...
Worker 0: variables initialized
Worker 0: incrementing var
Worker 0: incrementing var
Worker 0: incrementing var
Worker 0: incrementing var
Worker 0: incrementing var
Worker 0: blocking...
w2_proc.start()
Worker 1: waiting for cluster connection...
Worker 1: cluster ready!
Worker 1: waiting for variable initialization...
Worker 1: variables initialized
Worker 1: incrementing var
Worker 1: incrementing var
Worker 1: incrementing var
Worker 1: incrementing var
Worker 1: incrementing var
Worker 1: blocking...
for proc in [w1_proc, w2_proc, ps_proc]:
    proc.terminate()

Final words

We've looked at:

  • How to join together multiple TensorFlow execution engines (running on different processes or different machines) into a cluster, so that they can share variables.
  • How to place variables or operations on a specific server.
  • In-graph and between-graph replication.
  • What happens if we try to run operations on the cluster before all servers have connected, or after a server has left.
  • How to wait until variables have been initialized by another task in the cluster.

For more information and some more realistic examples, check out the official documentation at Distributed TensorFlow.