Clarification on the diagram #1

@huseinzol05

Based on the diagram, https://github.com/Zyphra/tree_attention/blob/main/images/tree.png

The diagram shows Q1, Q2, ..., Qn, but I believe all of those Qn are the same Q (Q1 == Q1+n), right? It has just been replicated to each device, based on your test code:

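import jax.numpy as jnp
from jax import lax, random as rand
from jax.sharding import NamedSharding

# `mesh` and `seq_spec` are defined elsewhere in the test script.
def make_data(shape):
    B, nh, T, C = shape
    k1, k2, k3 = rand.split(rand.PRNGKey(0), 3)
    # Q holds a single query token and gets no sharding constraint.
    Q = rand.normal(k1, (B, 1, nh, C)).astype(jnp.float16)
    # K and V hold T tokens each and are constrained to the `seq_spec` sharding.
    K = lax.with_sharding_constraint(
            rand.normal(k2, (B, T, nh, C)).astype(jnp.float16), NamedSharding(mesh, seq_spec)
    )
    V = lax.with_sharding_constraint(
            rand.normal(k3, (B, T, nh, C)).astype(jnp.float16), NamedSharding(mesh, seq_spec)
    )
    return Q, K, V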

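But I believe the Kn and Vn are different (K1 != K1+n and V1 != V1+n), since K and V are the tensors that get the `seq_spec` sharding constraint in that code. Is that correct? Thanks!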
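To make the question concrete, here is a minimal sketch of how I would check this locally. It is not the repository's code: the mesh axis name `"seq"`, the two sharding specs, and the small shapes are my own assumptions, and Q is replicated explicitly here, whereas the test code above leaves Q unconstrained. The sketch places Q and K on the mesh and then compares what each device actually holds.

```python
import os
# Assumption: simulate 4 host devices when only a CPU is available.
# This must be set before JAX is imported to have any effect.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import numpy as np
import jax
from jax import random as rand
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, axis_names=("seq",))  # hypothetical axis name

B, nh, C = 1, 2, 4
T = 8 * n  # keep the sequence length divisible by the device count
k1, k2 = rand.split(rand.PRNGKey(0), 2)

# Hypothetical specs: replicate Q everywhere, split K along the sequence axis.
q_sharding = NamedSharding(mesh, P())
k_sharding = NamedSharding(mesh, P(None, "seq", None, None))

Q = jax.device_put(rand.normal(k1, (B, 1, nh, C)), q_sharding)
K = jax.device_put(rand.normal(k2, (B, T, nh, C)), k_sharding)

# Pull the per-device pieces back to the host for comparison.
q_shards = [np.asarray(s.data) for s in Q.addressable_shards]
k_shards = [np.asarray(s.data) for s in K.addressable_shards]

# Every device should hold the same full copy of Q.
q_all_equal = all(np.array_equal(q_shards[0], s) for s in q_shards)
# Every device should hold a different T // n slice of K
# (trivially true when there is only one device).
k_all_distinct = all(not np.array_equal(k_shards[0], s) for s in k_shards[1:])
k_shard_shape = k_shards[0].shape

print("devices:", n)
print("Q identical on all devices:", q_all_equal)
print("K chunks differ across devices:", k_all_distinct)
print("per-device K shape:", k_shard_shape)
```

If the diagram follows the same layout, then Q1 ... Qn would be copies of one Q, while each Kn (and Vn) would be a distinct chunk of the sequence.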
