I see Q1, Q2, ..., Qn, but I believe these are all the same Q, right (Q1 == Q2 == ... == Qn)? Based on your test code, it looks like Q is just replicated to each device:
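```python
import jax.numpy as jnp
from jax import lax
from jax import random as rand
from jax.sharding import NamedSharding

# mesh and seq_spec are defined elsewhere in the test code
def make_data(shape):
    B, nh, T, C = shape
    k1, k2, k3 = rand.split(rand.PRNGKey(0), 3)
    # no sharding constraint on Q (single query position per head)
    Q = rand.normal(k1, (B, 1, nh, C)).astype(jnp.float16)
    # K and V are constrained to the sequence sharding seq_spec
    K = lax.with_sharding_constraint(
        rand.normal(k2, (B, T, nh, C)).astype(jnp.float16), NamedSharding(mesh, seq_spec)
    )
    V = lax.with_sharding_constraint(
        rand.normal(k3, (B, T, nh, C)).astype(jnp.float16), NamedSharding(mesh, seq_spec)
    )
    return Q, K, V
```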
But I believe the Kn and Vn are different across devices (K1 != K2, V1 != V2, and so on), since K and V are sharded with seq_spec. Thanks!
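To double-check my reading, here is a minimal sketch that inspects the per-device shards directly. It is not your test code: the 1-D mesh, the axis name `"seq"`, the `PartitionSpec` standing in for `seq_spec`, and using `jax.device_put` instead of `lax.with_sharding_constraint` are all my assumptions. On CPU it asks XLA for 4 host devices so the K shards are actually distinct.

```python
import os
# Assumption: force 4 CPU devices for the demo; only takes effect if set before JAX initializes
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import numpy as np
from jax import random as rand
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over all available devices, axis name "seq"
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("seq",))
n_dev = len(devices)

B, T, nh, C = 1, 8 * n_dev, 2, 4  # T chosen divisible by the device count

k1, k2 = rand.split(rand.PRNGKey(0), 2)
# Q: empty PartitionSpec -> fully replicated, every device holds the whole array
Q = jax.device_put(rand.normal(k1, (B, 1, nh, C)), NamedSharding(mesh, P()))
# K: split along the sequence dimension (axis 1), standing in for seq_spec
K = jax.device_put(
    rand.normal(k2, (B, T, nh, C)), NamedSharding(mesh, P(None, "seq", None, None))
)

q_shards = [np.asarray(s.data) for s in Q.addressable_shards]
k_shards = [np.asarray(s.data) for s in K.addressable_shards]

# Every Q shard should be an identical full copy
q_all_equal = all(np.array_equal(q_shards[0], s) for s in q_shards)
print("Q replicated:", q_all_equal, Q.sharding.is_fully_replicated)
print("Q shard shape:", q_shards[0].shape, "K shard shape:", k_shards[0].shape)
# Each K shard covers a different slice of the sequence axis
print("K shard indices:", [s.index for s in K.addressable_shards])
```

If my understanding is right, this should report Q as replicated, while each K shard holds T / n_dev sequence positions with a different `index` slice per device.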
For reference, my question is based on this diagram: https://github.com/Zyphra/tree_attention/blob/main/images/tree.png
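To check my understanding of why one shared Q is enough, here is a minimal sketch in plain `jax.numpy` (no real devices) of the usual log-sum-exp style merge for sequence-parallel decoding attention. This is my assumption of the idea behind the diagram, not the repo's implementation, and `chunk_partials` / `combine` are hypothetical helpers: each simulated device uses the same q with its own K/V chunk, and the per-chunk partials (numerator, denominator, local max) are merged and compared against ordinary full-sequence attention.

```python
import jax
import jax.numpy as jnp
from jax import random as rand

def chunk_partials(q, k, v):
    # q: (nh, C); k, v: (T_chunk, nh, C) -> partial results for one K/V chunk
    s = jnp.einsum("hc,thc->ht", q, k) / jnp.sqrt(q.shape[-1])  # scores (nh, T_chunk)
    m = s.max(axis=-1)                       # local max per head
    p = jnp.exp(s - m[:, None])              # numerically stabilized exponentials
    num = jnp.einsum("ht,thc->hc", p, v)     # unnormalized weighted sum of V
    den = p.sum(axis=-1)                     # local softmax denominator
    return num, den, m

def combine(parts):
    # Merge all chunk partials at once; the pairwise merge is associative,
    # so the same reduction could also be done as a tree across devices
    nums, dens, ms = map(jnp.stack, zip(*parts))  # (n, nh, C), (n, nh), (n, nh)
    m = ms.max(axis=0)                            # global max per head
    w = jnp.exp(ms - m)                           # rescale each chunk to the global max
    num = (w[..., None] * nums).sum(axis=0)
    den = (w * dens).sum(axis=0)
    return num / den[:, None]

nh, C, T, n_dev = 2, 4, 16, 4
kq, kk, kv = rand.split(rand.PRNGKey(0), 3)
q = rand.normal(kq, (nh, C))
k = rand.normal(kk, (T, nh, C))
v = rand.normal(kv, (T, nh, C))

# Same q for every simulated device; each one gets its own contiguous slice of K and V
parts = [chunk_partials(q, kc, vc) for kc, vc in zip(jnp.split(k, n_dev), jnp.split(v, n_dev))]
out_tree = combine(parts)

# Reference: ordinary softmax attention over the full sequence
s = jnp.einsum("hc,thc->ht", q, k) / jnp.sqrt(C)
out_ref = jnp.einsum("ht,thc->hc", jax.nn.softmax(s, axis=-1), v)
print(jnp.allclose(out_tree, out_ref, atol=1e-5))
```

If this matches what the diagram shows, the Qi boxes would all be the same q, and only the Ki/Vi chunks differ per device.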