Hello! Our static bug checker found a performance issue in scripts/global_optimal_proposal_variational.py: get_gradient_descent_function is called repeatedly in a for loop, and each call defines and calls a tf.function-decorated function, gradient_descent, inside it.

Because gradient_descent is redefined on every call to get_gradient_descent_function, each loop iteration creates a new tf.function object that has to trace a fresh graph. This wastes time and can trigger tf.function retracing warnings.

The TensorFlow documentation on tf.function retracing describes this behavior.
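Here is a minimal sketch of the difference, assuming TensorFlow 2.x in eager mode. The names make_step and shared_step are hypothetical stand-ins for get_gradient_descent_function and gradient_descent, not the script's actual code. The Python counter increments only while a function is being traced, so it counts how many graphs get built:

```python
import tensorflow as tf

# Hypothetical counters; Python side effects inside a tf.function run only during tracing.
trace_count = {"inner": 0, "outer": 0}

# Anti-pattern: the tf.function is created inside the factory, so every call to
# make_step() returns a brand-new tf.function with an empty trace cache.
def make_step():
    @tf.function
    def step(x):
        trace_count["inner"] += 1  # executes only while this function is traced
        return x * 2.0
    return step

for _ in range(3):
    make_step()(tf.constant(1.0))

# Preferred: define the tf.function once and reuse it across iterations.
@tf.function
def shared_step(x):
    trace_count["outer"] += 1  # executes only while this function is traced
    return x * 2.0

for _ in range(3):
    shared_step(tf.constant(1.0))

print(trace_count)
```

The first loop traces three separate graphs, one per new tf.function object. The second loop traces once, because every call passes a scalar float32 tensor with the same signature, so the cached graph is reused.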
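Briefly, for better efficiency, it is better to write:

```python
import tensorflow as tf

# Defined once, outside the caller, so the traced graph is reused on every call.
@tf.function
def inner():
    pass

def outer():
    inner()
```

than:

```python
import tensorflow as tf

def outer():
    # A new tf.function is created on every call to outer(), so it is retraced each time.
    @tf.function
    def inner():
        pass
    inner()
```

Looking forward to your reply. By the way, I would be glad to create a PR to fix this if you are too busy.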