问题:TensorFlow保存到文件中/从文件中加载图形

根据到目前为止的经验,有几种不同的方法可以将TensorFlow图转储到文件中,然后再将其加载到另一个程序中,但是我无法找到关于它们如何工作的清晰示例/信息。我已经知道的是:

  1. 使用a将模型的变量保存到检查点文件(.ckpt)中,tf.train.Saver()并在以后还原它们(
  2. 将模型保存到.pb文件,然后使用tf.train.write_graph()tf.import_graph_def()source)将其加载回
  3. 从.pb文件加载模型,对其进行重新训练,然后使用Bazel将其转储到新的.pb文件中(
  4. 冻结图形以将图形和权重保存在一起(
  5. 使用as_graph_def()保存模型,并为权重/变量,它们映射到常数(

但是,我无法清除有关这些不同方法的几个问题:

  1. 关于检查点文件,它们仅保存模型的训练权重吗?是否可以将检查点文件加载到新程序中并用于运行模型,还是仅将它们用作在特定时间/阶段将权重保存在模型中的方法?
  2. 关于tf.train.write_graph(),权重/变量也被保存吗?
  3. 关于Bazel,它只能保存到.pb文件中或从中加载以进行重新训练吗?是否有一个简单的Bazel命令只是将图形转储到.pb中?
  4. 关于冻结,是否可以使用来加载冻结图tf.import_graph_def()
  5. TensorFlow的Android演示从.pb文件加载到Google的Inception模型中。如果我想替换自己的.pb文件,该怎么做?我需要更改任何本机代码/方法吗?
  6. 通常,所有这些方法之间到底有什么区别?或更广泛地说,/。as_graph_def()ckpt / .pb有什么区别?

简而言之,我正在寻找一种将图形(如各种操作等)及其权重/变量都保存到文件中的方法,然后可以将其用于将图形和权重加载到另一个程序中,以供使用(不一定要继续/训练)。

关于此主题的文档不是很简单,因此,非常感谢您提供任何答案/信息。

From what I’ve gathered so far, there are several different ways of dumping a TensorFlow graph into a file and then loading it into another program, but I haven’t been able to find clear examples/information on how they work. What I already know is this:

  1. Save the model’s variables into a checkpoint file (.ckpt) using a tf.train.Saver() and restore them later (source)
  2. Save a model into a .pb file and load it back in using tf.train.write_graph() and tf.import_graph_def() (source)
  3. Load in a model from a .pb file, retrain it, and dump it into a new .pb file using Bazel (source)
  4. Freeze the graph to save the graph and weights together (source)
  5. Use as_graph_def() to save the model, and for weights/variables, map them into constants (source)

However, I haven’t been able to clear up several questions regarding these different methods:

  1. Regarding checkpoint files, do they only save the trained weights of a model? Could checkpoint files be loaded into a new program, and be used to run the model, or do they simply serve as ways to save the weights in a model at a certain time/stage?
  2. Regarding tf.train.write_graph(), are the weights/variables saved as well?
  3. Regarding Bazel, can it only save into/load from .pb files for retraining? Is there a simple Bazel command just to dump a graph into a .pb?
  4. Regarding freezing, can a frozen graph be loaded in using tf.import_graph_def()?
  5. The Android demo for TensorFlow loads in Google’s Inception model from a .pb file. If I wanted to substitute my own .pb file, how would I go about doing that? Would I need to change any native code/methods?
  6. In general, what exactly is the difference between all these methods? Or more broadly, what is the difference between as_graph_def()/.ckpt/.pb?

In short, what I’m looking for is a method to save both a graph (as in, the various operations and such) and its weights/variables into a file, which can then be used to load the graph and weights into another program, for use (not necessarily continuing/retraining).

Documentation about this topic isn’t very straightforward, so any answers/information would be greatly appreciated.


回答 0

有很多方法可以解决在TensorFlow中保存模型的问题,这可能会使它有些混乱。依次处理您的每个子问题:

  1. 检查点文件(例如产生通过调用一个上对象)只包含的权重,并且在相同程序中定义的任何其它变量。要在另一个程序中使用它们,您必须重新创建关联的图形结构(例如,通过运行代码以再次构建它,或调用),这告诉TensorFlow如何处理这些权重。请注意,调用saver.save()还会生成一个包含的文件MetaGraphDef,该文件包含一个图形以及如何将检查点的权重与该图形相关联的详细信息。有关更多详细信息,请参见教程

  2. tf.train.write_graph()只写图结构;不是重量。

  3. Bazel与读取或写入TensorFlow图无关。(也许我误会了您的问题:请随时在评论中予以澄清。)

  4. 冻结的图可以使用加载tf.import_graph_def()。在这种情况下,权重(通常)嵌入在图形中,因此您无需加载单独的检查点。

  5. 主要更改将是更新输入到模型中的张量的名称以及从模型中获取的张量的名称。在TensorFlow Android演示中,这将与传递给的inputNameoutputName字符串相对应TensorFlowClassifier.initializeTensorFlow()

  6. GraphDef是该程序的结构,其通常不通过训练过程而改变。检查点是训练过程状态的快照,通常在训练过程的每个步骤都会改变。结果,TensorFlow对这些类型的数据使用不同的存储格式,并且低级API提供了不同的方式来保存和加载它们。更高级别的库,如MetaGraphDef图书馆,Kerasskflow对这些机制的构建提供更加便捷的方式来保存和恢复整个模型。

There are many ways to approach the problem of saving a model in TensorFlow, which can make it a bit confusing. Taking each of your sub-questions in turn:

  1. The checkpoint files (produced e.g. by calling on a object) contain only the weights, and any other variables defined in the same program. To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or calling ), which tells TensorFlow what to do with those weights. Note that calling saver.save() also produces a file containing a MetaGraphDef, which contains a graph and details of how to associate the weights from a checkpoint with that graph. See the tutorial for more details.

  2. tf.train.write_graph() only writes the graph structure; not the weights.

  3. Bazel is unrelated to reading or writing TensorFlow graphs. (Perhaps I misunderstand your question: feel free to clarify it in a comment.)

  4. A frozen graph can be loaded using tf.import_graph_def(). In this case, the weights are (typically) embedded in the graph, so you don’t need to load a separate checkpoint.

  5. The main change would be to update the names of the tensor(s) that are fed into the model, and the names of the tensor(s) that are fetched from the model. In the TensorFlow Android demo, this would correspond to the inputName and outputName strings that are passed to TensorFlowClassifier.initializeTensorFlow().

  6. The GraphDef is the program structure, which typically does not change through the training process. The checkpoint is a snapshot of the state of a training process, which typically changes at every step of the training process. As a result, TensorFlow uses different storage formats for these types of data, and the low-level API provides different ways to save and load them. Higher-level libraries, such as the MetaGraphDef libraries, Keras, and skflow build on these mechanisms to provide more convenient ways to save and restore an entire model.


回答 1

您可以尝试以下代码:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)

You can try the following code:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)

声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。