-
Notifications
You must be signed in to change notification settings - Fork 255
Open
Description
Could you add support for prompt enhancement? I found that enhanced prompts give better results. The enhancement is quite model-specific, though: it works really well with Gemma 3, while other LLMs do not do nearly as well. PS: I2V needs the mmproj file 😄
def _enhance(
    self,
    messages: list[dict[str, object]],
    image: torch.Tensor | None = None,
    max_new_tokens: int = 512,
    seed: int = 42,
    temperature: float = 0.7,
) -> str:
    """Generate an enhanced prompt from a chat transcript and return only the new text.

    Args:
        messages: Chat messages rendered through the tokenizer's chat template.
            ``content`` may be a plain string (T2V) or a list of typed parts
            such as ``{"type": "image"}`` / ``{"type": "text", ...}`` (I2V).
        image: Optional image forwarded to the processor; ``None`` for text-only.
        max_new_tokens: Upper bound on the number of generated tokens.
        seed: Seed applied to the torch RNG for sampling. The CPU RNG state (and
            the model's CUDA device RNG state, when on CUDA) is forked, so those
            states are restored after generation.
        temperature: Sampling temperature. The default of 0.7 matches the value
            previously hard-coded here.

    Returns:
        The decoded generated continuation, with special tokens skipped.
    """
    # The processor is created lazily; initialise it before touching its tokenizer.
    if self.processor is None:
        self._init_image_processor()
    tokenizer = self.processor.tokenizer
    device = self.model.device

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = self.processor(
        text=text,
        images=image,
        return_tensors="pt",
    ).to(device)

    # Fall back to token id 0 when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id=pad_token_id)

    # fork_rng's `devices` are CUDA devices (default device_type="cuda"); passing a
    # CPU/MPS device makes it call torch.cuda.get_rng_state, which fails without
    # CUDA. Only list the model device when it actually is a CUDA device.
    rng_devices = [device] if device.type == "cuda" else []
    with torch.inference_mode(), torch.random.fork_rng(devices=rng_devices):
        torch.manual_seed(seed)
        outputs = self.model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
        )

    # generate() echoes the (padded) prompt; keep only the tokens produced after it.
    prompt_length = model_inputs.input_ids.shape[-1]
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
def enhance_t2v(
self,
prompt: str,
max_new_tokens: int = 512,
system_prompt: str | None = None,
seed: int = 42,
) -> str:
"""Enhance a text prompt for T2V generation."""
system_prompt = system_prompt or self.default_gemma_t2v_system_prompt
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"user prompt: {prompt}"},
]
return self._enhance(messages, max_new_tokens=max_new_tokens, seed=seed)
def enhance_i2v(
self,
prompt: str,
image: torch.Tensor,
max_new_tokens: int = 512,
system_prompt: str | None = None,
seed: int = 42,
) -> str:
"""Enhance a text prompt for I2V generation using a reference image."""
system_prompt = system_prompt or self.default_gemma_i2v_system_prompt
messages = [
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
],
},
]
return self._enhance(messages, image=image, max_new_tokens=max_new_tokens, seed=seed)
@functools.cached_property
def default_gemma_i2v_system_prompt(self) -> str:
    """Default I2V system prompt, loaded via ``_load_system_prompt`` and cached per instance."""
    prompt_file = "gemma_i2v_system_prompt.txt"
    return _load_system_prompt(prompt_file)
@functools.cached_property
def default_gemma_t2v_system_prompt(self) -> str:
    """Default T2V system prompt, loaded via ``_load_system_prompt`` and cached per instance."""
    prompt_file = "gemma_t2v_system_prompt.txt"
    return _load_system_prompt(prompt_file)
def forward(self, text: str, padding_side: str = "left") -> tuple[torch.Tensor, torch.Tensor]:
    """Encode ``text``; the base class provides no implementation.

    Raises:
        NotImplementedError: Always, on this base implementation.
    """
    message = "This method is not implemented for the base class"
    raise NotImplementedError(message)
mnector, Heliumrich and Mershl
Metadata
Metadata
Assignees
Labels
No labels