Source code for deepmd.utils.parallel_op

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Optional,
    Tuple,
)

from deepmd.env import (
    tf,
)
from deepmd.utils.sess import (
    run_sess,
)


[docs]class ParallelOp: """Run an op with data parallelism. Parameters ---------- builder : Callable[..., Tuple[Dict[str, tf.Tensor], Tuple[tf.Tensor]]] returns two objects: a dict which stores placeholders by key, and a tuple with the final op(s) nthreads : int, optional the number of threads config : tf.ConfigProto, optional tf.ConfigProto Examples -------- >>> from deepmd.env import tf >>> from deepmd.utils.parallel_op import ParallelOp >>> def builder(): ... x = tf.placeholder(tf.int32, [1]) ... return {"x": x}, (x + 1) ... >>> p = ParallelOp(builder, nthreads=4) >>> def feed(): ... for ii in range(10): ... yield {"x": [ii]} ... >>> print(*p.generate(tf.Session(), feed())) [1] [2] [3] [4] [5] [6] [7] [8] [9] [10] """ def __init__( self, builder: Callable[..., Tuple[Dict[str, tf.Tensor], Tuple[tf.Tensor]]], nthreads: Optional[int] = None, config: Optional[tf.ConfigProto] = None, ) -> None: if nthreads is not None: self.nthreads = nthreads elif config is not None: self.nthreads = max(config.inter_op_parallelism_threads, 1) else: self.nthreads = 1 self.placeholders = [] self.ops = [] for ii in range(self.nthreads): with tf.name_scope("task_%d" % ii) as scope: placeholder, op = builder() self.placeholders.append(placeholder) self.ops.append(op)
[docs] def generate( self, sess: tf.Session, feed: Generator[Dict[str, Any], None, None] ) -> Generator[Tuple, None, None]: """Returns a generator. Parameters ---------- sess : tf.Session TensorFlow session feed : Generator[dict, None, None] generator which yields feed_dict Yields ------ Generator[Tuple, None, None] generator which yields session returns """ nn = self.nthreads while True: feed_dict = {} for ii in range(self.nthreads): try: fd = next(feed) except StopIteration: if ii == 0: return nn = ii break for kk, vv in fd.items(): feed_dict[self.placeholders[ii][kk]] = vv ops = self.ops[:nn] yield from run_sess(sess, ops, feed_dict=feed_dict)