Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions python/paddle/v2/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ def __init__(self, pass_id):
class EndPass(WithMetric):
"""
Event On One Pass Training Complete.
To get the output of a specific layer, add "event.gm.getLayerOutputs('predict_layer')"
in your event_handler call back
"""

def __init__(self, pass_id, evaluator):
def __init__(self, pass_id, evaluator, gm):
self.pass_id = pass_id
self.gm = gm
Copy link
Contributor

@lcy-seso lcy-seso Sep 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有没有可能把这个 gm 再封一层,只提供 getLayerOutput 方法,只接受一个layer_name作为参数。如果用户能拿到这个 gm 应该可以操作很多东西,似乎有点危险。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我理解也没必要再wrap好多层了,如果用户期望那到gm做“危险”的事情,那必然也需要对比较底层的hack比较清楚。因为文档只写了调用getLayerOutputs,一般用户也不会调用其他的方法。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯~ 明白了。

WithMetric.__init__(self, evaluator)


Expand All @@ -73,10 +76,13 @@ def __init__(self, pass_id, batch_id):
class EndIteration(WithMetric):
"""
Event On One Batch Training Complete.
To get the output of a specific layer, add "event.gm.getLayerOutputs('predict_layer')"
in your event_handler call back
"""

def __init__(self, pass_id, batch_id, cost, evaluator):
def __init__(self, pass_id, batch_id, cost, evaluator, gm):
self.pass_id = pass_id
self.batch_id = batch_id
self.cost = cost
self.gm = gm
WithMetric.__init__(self, evaluator)
9 changes: 7 additions & 2 deletions python/paddle/v2/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,18 @@ def train(self, reader, num_passes=1, event_handler=None, feeding=None):
pass_id=pass_id,
batch_id=batch_id,
cost=cost,
evaluator=batch_evaluator))
evaluator=batch_evaluator,
gm=self.__gradient_machine__))
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()

self.__parameter_updater__.finishPass()
pass_evaluator.finish()
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
event_handler(
v2_event.EndPass(
pass_id,
evaluator=pass_evaluator,
gm=self.__gradient_machine__))
self.__gradient_machine__.finish()

def test(self, reader, feeding=None):
Expand Down