Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feature(wrh): add harmony dream in unizero #255

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

ruiheng123
Copy link
Contributor

@ruiheng123 ruiheng123 commented Jul 31, 2024

@puyuan1996 puyuan1996 added the enhancement New feature or request label Aug 5, 2024
@@ -164,19 +176,60 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg
self.loss_total = torch.tensor(0., device=device)
for k, v in kwargs.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# Define a dictionary for loss weights and harmony_s variables
loss_weights = {
    'loss_obs': (self.obs_loss_weight, 'loss_obs_harmony_s'),
    'loss_rewards': (self.reward_loss_weight, 'loss_rewards_harmony_s'),
    'loss_policy': (self.policy_loss_weight, 'loss_policy_harmony_s'),
    'loss_value': (self.value_loss_weight, 'loss_value_harmony_s'),
    'loss_ends': (self.ends_loss_weight, 'loss_ends_harmony_s'),
    'latent_recon_loss': (self.latent_recon_loss_weight, 'latent_recon_loss_harmony_s'),
    'perceptual_loss': (self.perceptual_loss_weight, 'perceptual_loss_harmony_s')
}

# Iterate through kwargs to process the losses
for k, v in kwargs.items():
    if k in loss_weights:
        weight, harmony_var_name = loss_weights[k]
        harmony_s = globals().get(harmony_var_name)  # Get the harmony_s variable by name

        if harmony_s_dict is None:
            self.loss_total += weight * v
        elif harmony_s is not None:
            self.loss_total += (v / torch.exp(harmony_s)) + torch.log(torch.exp(harmony_s) + 1)
        else:
            self.loss_total += weight * v

)


# else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants