Llama3 like weight init #435

Open

le1nux wants to merge 5 commits into improve_data_writeout_perf from llama3_like_weight_init

Conversation


@le1nux le1nux commented Mar 4, 2026

What does this PR do?

This PR ..

General Changes

  • ..

Breaking Changes

  • ..

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

@le1nux le1nux marked this pull request as ready for review March 4, 2026 17:55
@le1nux le1nux requested a review from AbasKhan March 4, 2026 18:52
@le1nux le1nux changed the base branch from main to improve_data_writeout_perf March 4, 2026 19:22
Comment on lines +509 to +512
config_file_path = Path(__file__).parent / "test_yaml_configs/llama3_config_initalization.yaml"
n_layer = 4
n_embd = 256
model = self._get_components(config_file_path=config_file_path, has_bias=has_bias)
Member

Please turn this into a @pytest.fixture(scope="module") and the tests into true, separate tests.

),
ComponentEntity(
"model_initialization",
"llama3_like",
Member

Suggested change
"llama3_like",
"gpt2_llama3_like",

It should be clear that this component only works for our gpt2 model.


self.regex_to_init = {
# embedding weights
r"transformer\.wte\.weight": partial(nn.init.normal_, mean=0.0, std=1),
Member

Is this really gonna be std=1?

Collaborator

Similar in function to torchtitan's weight init nn.init.normal_(self.tok_embeddings.weight), you can see it here. mean=0.0, std=1 are then unnecessary, right? Since they are the default values as per the documentation.

match_count += 1
hits[weight_regex] += 1
if match_count == 0:
logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
Member

Should we add a flag which turns this into an error?
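A flag like that could look as follows. This is a torch-free sketch with hypothetical names (apply_inits, strict); the init functions here are plain callables standing in for the partial(nn.init...) entries.

```python
import logging
import re

logger = logging.getLogger(__name__)


def apply_inits(named_parameters, regex_to_init, strict=False):
    """Apply every matching init to each parameter; warn (or, with the
    hypothetical strict flag, raise) when a parameter matches no regex."""
    for parameter_name, parameter in named_parameters:
        match_count = 0
        for weight_regex, init_fn in regex_to_init.items():
            if re.fullmatch(weight_regex, parameter_name):
                init_fn(parameter)
                match_count += 1
        if match_count == 0:
            msg = f"Parameter {parameter_name} did not match any regex for initialization"
            if strict:
                raise ValueError(msg)
            logger.warning(msg)
```

Defaulting strict to False keeps the current warning behavior, while strict=True makes an incomplete regex table fail fast instead of silently leaving parameters at their default init.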

b=2,
),
}
if bias:
Member

Shouldn't this depend on whether the given model has biases or not?

Collaborator

Adding to that: for Llama, titan has bias set to False for the attention linears, see here.
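One way to make the bias flag depend on the model rather than on a config value would be to derive it from the parameter names themselves. A hypothetical helper (model_has_bias is not in the PR):

```python
def model_has_bias(named_parameters) -> bool:
    # Derive the flag from the model itself: if any parameter name ends
    # in ".bias", the model was built with biases enabled.
    return any(name.endswith(".bias") for name, _ in named_parameters)
```

Called with model.named_parameters(), this avoids the failure mode where the init component registers bias patterns for a model that has none (or vice versa).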

b=2,
),
# SwiGLU
r"transformer\.h\.\w+\.mlp\.(W)\.weight": partial(
Member

Why the switch to \w+ from \d+ here and below?
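For context, the practical difference between the two patterns, assuming full-match semantics (the capture group around W is dropped and the parameter names are illustrative):

```python
import re

d_pat = r"transformer\.h\.\d+\.mlp\.W\.weight"
w_pat = r"transformer\.h\.\w+\.mlp\.W\.weight"

# Both patterns match a numeric layer index ...
name = "transformer.h.12.mlp.W.weight"
assert re.fullmatch(d_pat, name) and re.fullmatch(w_pat, name)

# ... but \w+ is looser: it also accepts a non-numeric segment
# (hypothetical parameter name, shown only to illustrate the difference).
loose = "transformer.h.foo.mlp.W.weight"
assert re.fullmatch(d_pat, loose) is None
assert re.fullmatch(w_pat, loose) is not None
```

Since the h.<n> segment is always a numeric layer index in this model, \d+ is the tighter and safer choice.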


b=2,
),
# final attention projection in attention block
r"transformer\.h\.\d+\.attn\.c_proj\.weight": partial(
Collaborator

This corresponds to the following, right? But there you can see that for the out projection it's std=init_std, which can be initialized differently and defaults to depth init, because here we pass weight_init_std, which defaults to depth_init in titan (here). If we don't want depth init, then it matches the scaled out_projection logic when depth_init is False in titan.
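For reference, the depth-dependent std the comment alludes to can be sketched like this. This is written from memory of torchtitan's Llama init logic; treat the exact constants and formula as an assumption, not a copy of the source.

```python
def out_projection_std(layer_id: int, num_layers: int,
                       depth_init: bool, base_std: float = 0.02) -> float:
    # With depth_init, deeper layers get a smaller init std (scaled by the
    # layer index); without it, the std is scaled by the total layer count.
    if depth_init:
        return base_std / (2 * (layer_id + 1)) ** 0.5
    return base_std / (2 * num_layers) ** 0.5
```

So with depth_init=False every layer's out projection shares one scaled std, which is the case the PR's fixed weight_init_std would match.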

def __init__(self, num_layers: int, n_embd: int, bias: bool) -> None:
super().__init__()

self.regex_to_init = {
Collaborator

We also need regex patterns for attention_norm, ffn_norm, and the final lm_head_norm, no? Something like:

r"transformer\.h\.\d+\.(attention_norm|ffn_norm)\.weight": nn.init.ones_,
r"transformer\.lm_head_norm\.weight": nn.init.ones_,
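A quick sanity check that the two suggested patterns cover the intended parameters. The parameter names are assumed from the suggestion itself, and matching is assumed to be full-match:

```python
import re

norm_patterns = [
    r"transformer\.h\.\d+\.(attention_norm|ffn_norm)\.weight",
    r"transformer\.lm_head_norm\.weight",
]

# Parameter names assumed from the suggested patterns above.
names = [
    "transformer.h.0.attention_norm.weight",
    "transformer.h.3.ffn_norm.weight",
    "transformer.lm_head_norm.weight",
]
for n in names:
    assert any(re.fullmatch(p, n) for p in norm_patterns)
```

Mapping these to nn.init.ones_ would make the norm-scale init explicit instead of relying on the modules' defaults (and would also silence the no-match warning for those parameters).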
