extract only html content
This commit is contained in:
parent
89e442423c
commit
edfd16ef1d
0
backend/codegen/__init__.py
Normal file
0
backend/codegen/__init__.py
Normal file
57
backend/codegen/test_utils.py
Normal file
57
backend/codegen/test_utils.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import unittest
|
||||||
|
from codegen.utils import extract_html_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtils(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_extract_html_content_with_html_tags(self):
|
||||||
|
text = "<html><body><p>Hello, World!</p></body></html>"
|
||||||
|
expected = "<html><body><p>Hello, World!</p></body></html>"
|
||||||
|
result = extract_html_content(text)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_extract_html_content_without_html_tags(self):
|
||||||
|
text = "No HTML content here."
|
||||||
|
expected = "No HTML content here."
|
||||||
|
result = extract_html_content(text)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_extract_html_content_with_partial_html_tags(self):
|
||||||
|
text = "<html><body><p>Hello, World!</p></body>"
|
||||||
|
expected = "<html><body><p>Hello, World!</p></body>"
|
||||||
|
result = extract_html_content(text)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_extract_html_content_with_multiple_html_tags(self):
|
||||||
|
text = "<html><body><p>First</p></body></html> Some text <html><body><p>Second</p></body></html>"
|
||||||
|
expected = "<html><body><p>First</p></body></html>"
|
||||||
|
result = extract_html_content(text)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
## The following are tests based on actual LLM outputs
|
||||||
|
|
||||||
|
def test_extract_html_content_some_explanation_before(self):
|
||||||
|
text = """Got it! You want the song list to be displayed horizontally. I'll update the code to ensure that the song list is displayed in a horizontal layout.
|
||||||
|
|
||||||
|
Here's the updated code:
|
||||||
|
|
||||||
|
<html lang="en"><head></head><body class="bg-black text-white"></body></html>"""
|
||||||
|
expected = '<html lang="en"><head></head><body class="bg-black text-white"></body></html>'
|
||||||
|
result = extract_html_content(text)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_markdown_tags(self):
|
||||||
|
text = "```html<head></head>```"
|
||||||
|
expected = "```html<head></head>```"
|
||||||
|
result = extract_html_content(text)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_doctype_text(self):
|
||||||
|
text = '<!DOCTYPE html><html lang="en"><head></head><body></body></html>'
|
||||||
|
expected = '<html lang="en"><head></head><body></body></html>'
|
||||||
|
result = extract_html_content(text)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
14
backend/codegen/utils.py
Normal file
14
backend/codegen/utils.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def extract_html_content(text: str):
|
||||||
|
# Use regex to find content within <html> tags and include the tags themselves
|
||||||
|
match = re.search(r"(<html.*?>.*?</html>)", text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
return match.group(1)
|
||||||
|
else:
|
||||||
|
# Otherwise, we just send the previous HTML over
|
||||||
|
print(
|
||||||
|
"[HTML Extraction] No <html> tags found in the generated content: " + text
|
||||||
|
)
|
||||||
|
return text
|
||||||
@ -2,6 +2,7 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
|
from codegen.utils import extract_html_content
|
||||||
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
||||||
from custom_types import InputMode
|
from custom_types import InputMode
|
||||||
from llm import (
|
from llm import (
|
||||||
@ -312,6 +313,9 @@ async def stream_code(websocket: WebSocket):
|
|||||||
|
|
||||||
print("Exact used model for generation: ", exact_llm_version)
|
print("Exact used model for generation: ", exact_llm_version)
|
||||||
|
|
||||||
|
# Strip the completion of everything except the HTML content
|
||||||
|
completion = extract_html_content(completion)
|
||||||
|
|
||||||
# Write the messages dict into a log so that we can debug later
|
# Write the messages dict into a log so that we can debug later
|
||||||
write_logs(prompt_messages, completion) # type: ignore
|
write_logs(prompt_messages, completion) # type: ignore
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user