ソースを参照

trim too long text

metya 2 年 前
コミット
82b241acc8
3 ファイル変更32 行追加20 行削除
  1. 3 0
      .gitignore
  2. 29 20
      app.py
  3. BIN
      static/audio.wav

+ 3 - 0
.gitignore

@@ -152,9 +152,12 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
 
+
 # PyCharm
 #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+*.wav

+ 29 - 20
app.py

@@ -5,7 +5,10 @@ import torch
 import openai
 
 
-openai.api_key = ""
+with open('.env', 'r') as env:
+    key = env.readline().strip()
+
+openai.api_key = key
 
 device = torch.device('cpu')
 torch.set_num_threads(4)
@@ -34,7 +37,6 @@ class State:
     new_audio = False
     need_audio = False
     need_generation_from_client = True
-    big_head = False
 
 state = State()
 
@@ -49,8 +51,10 @@ def send_data():
     data = request.form['data']
     need_generation = request.form['state']
     state.need_generation_from_client = (need_generation in ["true", "True"])
-    # if state.need_generation_from_client and state.count > 5:
-    #     state.count = 0
+    if state.need_generation_from_client and state.count > 20:
+        state.count = 0
+        state.need_generation = True
+
     # Обработка полученных данных
     detections = json.loads(data)
     if detections['face']:
@@ -61,28 +65,26 @@ def send_data():
             # emotion = max(set(state['emotion']), key=state['emotion'].count), 
             # sex = max(set(state['gender']), key=state['gender'].count), 
             # age = sum(state['age'])/len(state['age']),
-            # state.emotion, state.age, state.gender = [], [], []
-            emotion = detections['face'][0]['emotion'][0]['emotion']
+            state.emotion, state.age, state.gender = [], [], []
+            emotion = detections['face'][0]['emotion']
             sex = detections['face'][0]['gender']
             age = detections['face'][0]['age']
             app.logger.info(f'\n{emotion=}, \n{sex=}, \n{age=}') 
             state.prompt = generate_prompt(emotion, age, sex)
             state.generation_text = generate_text(state.prompt) 
         elif detections['face'][0]['size'][0] > 200:
-            # state.age.append(detections['face'][0]['age'])
-            # state.gender.append(detections['face'][0]['gender'])
-            # state.emotion.append(detections['face'][0]['emotion'][0]['emotion'])
-            state.big_head = True
+            state.age.append(detections['face'][0]['age'])
+            state.gender.append(detections['face'][0]['gender'])
+            state.emotion.append(detections['face'][0]['emotion'][0]['emotion'])
             state.count += 1
         else:
-            state.big_head = False
             state.count -= 1
     else:
         state.count -= 1
         # state["size"].append(detections['face'][0]['size'][0])
         # print(detections['face'][0])
     # print(detections['face'][0]['age'], detections['face'][0]['emotion'], detections['face'][0]['gender'])
-    app.logger.info(f"STATUS {state.count=}, {state.need_generation=}, {state.need_generation_from_client=}")
+    app.logger.info(f"{state.count=}")
 
     return data
 
@@ -90,7 +92,7 @@ def send_data():
 def check_audio():
     app.logger.info(f'checking need generation {state.need_generation=}, {state.need_audio=}')
 
-    if state.need_audio and state.big_head:
+    if state.need_audio:
         generate_audio(state.generation_text)
     else:
         state.new_audio = False
@@ -104,12 +106,11 @@ def check_audio():
         'text': state.generation_text,
         'prompt': state.prompt
     }
-    app.logger.info("response ready")
+    app.logger.info("response if ready")
     return jsonify(response)
 
 @app.route("/audio.wav")
 def audio():
-    state.need_generation = True
     # print("Requested path:", request.path)
     # print("File path:", os.path.join(app.static_folder, 'audio.wav'))
     return app.send_static_file('audio.wav')
@@ -127,7 +128,7 @@ def generate_prompt(emotion, age, sex):
     prompt = f'''Ты - это арт объект в виде носимого устройства. Твое название - Onism Machina.
 Твоя задача - изучить эмоции человека, чтобы больше понимать о том, что это такое и таким образом познать себя, ведь ты - Искусственный Интеллект.
 Тебе также интересна симбиотическая связь человека и ИИ.
-К тебе подходит человек и он показывает эмоцию {emotion}. Ему {age} лет.
+К тебе подходит человек и он показывает эмоцию "{emotion}". Ему {age} лет.
 Твоя нейросеть распознала эту эмоцию и теперь тебе нужно дать какой-то необычный концептуальный ответ.
 Что ты скажешь этому человеку? 
 
@@ -148,7 +149,7 @@ def generate_text(prompt):
     response = openai.ChatCompletion.create(
                         model="gpt-3.5-turbo",
                         temperature=1,
-                        # max_tokens=1000,
+                        max_tokens=1000,
                         messages=[
                                 {"role": "system", "content": "Ты — это арт объект выставки про взаимодействие машины и человека."},
                                 {"role": "user", "content": prompt},
@@ -161,17 +162,25 @@ def generate_audio(sample_text):
     app.logger.info('starting audio generation')
     state.need_audio = False
     state.need_generation = False
-    audio_paths = model.save_wav(text=sample_text,
+    text = trim_text(sample_text)
+    audio_paths = model.save_wav(text=text,
                                 speaker=speaker,
                                 sample_rate=sample_rate,
                                 audio_path="static/audio.wav")
     app.logger.info('generating audio is done')
     state.new_audio = True
 
-
+def trim_text(example_text):
+    if len(example_text) >= 1000:
+        app.logger.info('TEXT IS TOO LONG - TRIM!')
+        for i in range(1000, 500, -1):
+            if example_text[i] in ['.', '?', '...']:
+                return example_text[:i+1]
+    else:
+        return example_text
 
 if __name__ == '__main__':
-    app.logger.setLevel("DEBUG")
     app.logger.info('start app')
+    app.logger.setLevel("DEBUG")
     app.run(debug=True, host="0.0.0.0") 
         # ssl_context=("127.0.0.1.pem", "127.0.0.1-key.pem"))

BIN
static/audio.wav