20210801のPythonに関する記事は30件です。

Tweepyで世界各地のトレンドを調べてみた。(Python初心者)

はじめに 著者はPython初心者です。TwitterAPIに興味がある人、あまり使ったことが無い人の参考になればうれしいです。今回の記事ではTwitterAPIのアクセス権限がある前提で話を進めていくので、まだTwitterAPIで遊んだことがない人は前回の記事を参考にしていただければと思います。 https://qiita.com/tomo045/items/7aa50b45702f105bcee0 動作環境 OS→Windows10、Pythonバージョン→3.7.8、コードの実行→Jupyter Notebook(VS Code) コード import tweepy #API認証の設定 API_KEY = "*********************" API_SECRET = "*******************" ACCESS_TOKEN = "********************" ACCESS_TOKEN_SECRET = "***********************" auth = tweepy.OAuthHandler(API_KEY, API_SECRET) auth.set_access_token(ACCESS_TOKEN, ACCESS_TOKEN_SECRET) api = tweepy.API(auth) #地名とwoeidを紐づけた辞書を作成 woeid = { "New York":2459115,"London":44418,"Paris":615702,"Rome":721943, "Tokyo":1118370, "Kenya":23424863, "São Paulo":455827, "Ukraine":23424976, "Zurich":784794 } for area, wid in woeid.items(): print("--- {} ---".format(area)) trends = api.trends_place(wid)[0] for content in trends["trends"][:1] print(content['name']) トレンドを調べるためにtweepyのTrendsメソッドを使います。The Yahoo! Where On Earth IDによって場所を設定します。辞書型変数woeidに調べたい都市名(area)とThe Yahoo! Where On Earth ID(wid)を入力しましょう。woeidは以下のサイトで参照しました。 https://muftsabazaar.com/blog/category/?c=WOEID trends = api.trends_place(wid)[0] for content in trends["trends"][:1] print(content['name']) コードの中でも一番難しいのが上記の部分です。 testTrends = api.trends_place(1118370)[0] print(testTrends) trendsをわかりやすく説明するために、このコードを実行して東京のトレンドを取得してみます。 {'trends': [{'name': 'ソッサスブレイ', 'url': 'http://twitter.com/(略)', 'promoted_content': None, 'query': '%E3%(略)', 'tweet_volume': None} ... {'name': 'パーパット', 'url': 'http://twitter.com/(略)', 'promoted_content': None, 'query': '%E3%(略)', 'tweet_volume': None}], 'as_of': '2021-08-01T08:34:28Z', 'created_at': '2021-07-29T07:05:28Z', 'locations': [{'name': 'Tokyo', 'woeid': 1118370}]} 実行結果が出ました。メソッドの仕様上'trends'の要素は50個取得されるため、...で省略しました。この実行結果から、testTrendsは特殊な辞書型変数だと分かります。なぜ特殊かというと、辞書の要素がリストだからです。 content = testTrends["trends"][:1] print(content) 上のコードでtrendsキーの要素をcontentに格納します。ちなみに[:1]は50個のトレンド内容から1個だけ取り出すためのスライス操作です。 [{'name': '#ひきこもりたちでもフェスがしたい', 'url': 'http://twitter.com/(略)', 'promoted_content': None, 'query': '%23(略)', 'tweet_volume': None}] このような結果になりました。ここからnameキーの要素を取り出したいのですが、これはリストなのでなかなか取り出しづらいです。 for content in testTrends["trends"][:1]: print(content['name']) なので、for文を使ってリストtrends["trends"][:1]を辞書に変換してcontentに格納します。その後、nameキーの要素を出力させます。 #東京ドーム5個分 トレンド内容を取得できることが分かります。 結果(メインコード) 無事に世界各地のトレンドを調べることが出来ました! Excelに出力 import datetime import xlsxwriter #エクセルファイルを作成、または既存のファイルを開く wb = xlsxwriter.Workbook('trend.xlsx') #エクセルシートを作成 ws = wb.add_worksheet("sheet1") #日付の書式設定 format = wb.add_format() format.set_num_format('yy/mm/dd') #セルに書き込み ws.write(0, 0, "Place") ws.write(0, 1, "Trend") ws.write(0, 2, "Data") j = 1 for area, wid in woeid.items(): ws.write(j, 0, area) trends = api.trends_place(wid)[0] for content in trends["trends"][:1]: #トレンドをプリント ws.write(j, 1, content['name']) #日付をプリント ws.write(j, 2, datetime.date.today(), format) #改行用 j = j + 1 wb.close() このように場所、トレンド、日付が書かれた表がプログラムファイルと同じ場所に出力されます。 参照サイト https://www.pytry3g.com/entry/python-twitter-api https://kurozumi.github.io/tweepy/api.html https://muftsabazaar.com/blog/category/?c=WOEID
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

PythonでAtCoder Beginners Selection 解答と感想

プログラミングを勉強してみようと思ったけど、普通にやっても続く気がしないのでゲーム感覚で競技プログラミングをやってみようと思った。 https://atcoder.jp/contests/abs Welcome to AtCoder 題名からして最初の一問 input()で入力を受け取って、int()で数値に変換して、+で足して、print()で出力するくらいかな。 A = int(input()) BC = input() S = input() B,C = BC.split( ) print(str(A+int(B)+int(C))+" " + S) Product Welcome to AtCoderよりも簡単そう 算術演算子の+と%、if文が使えれば特に難しいことはないかな ここでmap()を覚えてコードがより短くなった A,B = map(int,input().split()) if (A*B % 2) == 0: print("Even") else: print("Odd") Placing Marbles 発想としては1が書かれたマスにビー玉を置くんだから出てくる数字を全部足せば良いと思った。 なので文字列を文字ごとに分割する方法を検索したらlist()が出てきた。 listっていうクラスがあって、それの型コンストラクタを使うと文字ごとに分割されるらしい。 クラスとか型コンストラクタとかに関してはまだよくわかっていないから後で詳しくやりたい。 A,B,C = map(int,list(input())) print(A+B+C) Shift only 全ての要素についてtrueかどうかを判定したいと思って調べたらall()とany()が出てきた。 map()を覚えて使いまくってる。ジェネレータ式というものがあるらしく、覚えたらそっちの方が簡単に書けそう。 input() A = list(map(int,input().split())) def half(n): return n/2 def evenJudge(n): return n % 2 == 0 i = 0 while True: if all(list(map(evenJudge,A))): A = list(map(half,A)) i = i + 1 else: break print(i) Coins 多重ループができれば解けるっていう問題かな。 別の場所で似たような問題を解いたことがあるけど、拡張性を考えて再起関数を使ったような記憶。 競技プログラミングでは実行速度が大事だけど、システムを作る時には拡張性や可読性を考えて作らないといけないため、どういうプログラムを書くかは考えて学んでいかないといけないと思った。 A = int(input()) B = int(input()) C = int(input()) X = int(input()) n=0 for i in range(A+1): for j in range(B+1): for k in range(C+1): if i * 500 + j * 100 + k * 50 == X: n = n+1 print(n) Some Sums 条件に沿えば足していくだけかなと 条件をどう判定するか、各桁の足し算をどう行うかが肝なのかな? (もう少し綺麗に書けたらと思うんですけど何かありますかね) N,A,B = map(int,input().split()) num = 0 for i in range(N+1): if A<= sum(map(int,list(str(i)))) <= B: num = num + i print(num) Card Game for Two .sort()で大きい順に並べて、偶数番目を足して奇数番目を引く(0番目から数えて)。 if文とループを使ったけどスライスっていうものがあって偶数番目だけを取り出すのはA[::2]、奇数番目だけを取り出すのはA[1::2]とすれば良いらしい。 見たことある気はするけどよくわからなかったからスルーしてた。実際に使う場面は結構ありそうだから覚えておかねば。 N = int(input()) A = list(map(int,input().split())) A.sort() point = 0 for i in range(N): m = A.pop() if i % 2 == 0: point += m else: point -= m print(point) Kagami Mochi ソートして0番目から順番に値が大きくなっていたらカウントっていう風にした。 set()という型コンストラクタを使えば重複がないデータの集合ができるらしく、len(set(D))で良いらしい。 N = int(input()) D = [] for i in range(N): d = int(input()) D.append(d) D.sort() n = 1 for i in range(N-1): if D[i] < D[i+1]: n = n + 1 print(n) Otoshidama i,jが決まるとkが一意になるのに気づかずにWAを出しまくった。 ある程度のことはプログラムがやってくれるからと思って処理を短くする努力を全然していなかったかな。 あと、return()以外の上手いやりかたが思いつかずに関数を定義してしまったけど、普通にexit()を使えばよかった。 知識として聞いたことはあっても身についてないことが多いかなぁ。 N,Y = map(int,input().split()) def check(N,Y): for i in reversed(range(N+1)): for j in reversed(range(N+1-i)): k = N-i-j if i * 10000 + j * 5000 + k * 1000 == Y: return(str(i) + " " + str(j) + " " + str(k)) return('-1 -1 -1') print(check(N,Y)) 白昼夢 前から一致している文字列を消していけば良いかと思ったが、dreamerとdreameraseの見分けがつかなくて前からだと厳しそう。後ろからやっていけば良いのではと考える。 結果ゴリ押し感が否めない拡張性皆無の見にくいソースが出来上がってしまった。 調べてみると指定文字列で終了するか調べる.endswith()というものがあったため、それを使えばもう少し見た目の良いソースができそう。 S = input() while len(S) != 0: if S[-5:] == "dream" or S[-5:] == "erase": S = S[:-5] elif S[-6:] == "eraser": S = S[:-6] elif S[-7:] == "dreamer": S = S[:-7] else: print("NO") exit() print("YES") 書き直し 下のように書き直してみた。最初は.removesuffix()を使おうとしたんだけどREが出てしまった。 この辺り原因すらよくわかってないからその辺の知識もつけていかないといけないなぁと。 S = input() while len(S) != 0: for removeS in ["erase","dream","eraser","dreamer"]: if S.endswith(removeS): S = S[:-len(removeS)] break else: print("NO") exit() print("YES") Traveling 距離的に届くかどうかと毎秒動き続けるのでちょうどよく止まれるかどうかの2条件を確認すれば良いかなと。 同じループを2回書いてしまっているし、全部配列に入れなくても処理できるので、前の地点だけ保存しておいて、受け取ってすぐに処理するようにすればよかったと思った。 N = int(input()) TXY = [[0,0,0]] for i in range(N): TXY.append(list(map(int,input().split()))) for i in range(N): if abs(TXY[i+1][2]-TXY[i][2]) + abs(TXY[i+1][1]-TXY[i][1]) <= TXY[i+1][0] - TXY[i][0]\ and (abs(TXY[i+1][2]-TXY[i][2]) + abs(TXY[i+1][1]-TXY[i][1]))%2 == (TXY[i+1][0] - TXY[i][0]) % 2: continue else: print("No") exit() print("Yes") まとめ プログラミング言語の知識不足で色々調べなきゃいけないことは多かったが、Beginners Selectionというだけあって根本の考え方がどうしてもわからないと言ったことはなかった。 プログラム自体はそんなに時間かからずにかけたけど記事を書くのにだいぶ時間がかかってしまったなぁという感想。 記事を書く際に概念が全然理解できてないから調べて書くっていうことが多かった。(大したことは書いてないけれども) あとは単純に日本語力がないかな・・・・ 記事の書き方も書いていく上で向上していけたらなぁって感じですね。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

BERTで英検を解く

英検の大問1は、短文穴埋め問題になっています。例えば、こういう問題です。 My sister usually plays tennis (   ) Saturdays.  1. by  2. on  3. with  4. at Bob (   ) five friends to his party.  1. made  2. visited  3. invited  4. spoke 文の中の隠された部分に入るものを、選択肢の中から答える問題です。文法的な判断もあれば、文脈から意味の通りが良い単語を選ぶ問題もあります。5級から1級まですべての難易度で出題される形式です。 この問題形式は、BERT (Bidirectional Encoder Representations from Transformers)の学習アルゴリズム(のうちの1つ)とよく似ています。ということは、事前学習済みのBERTモデルで英検の問題は解けるのではないか、ということで実際に解いてみました。 * BERTの代表的な学習方法の1つに Masked Language Modelというものがあります。これは、実在する文章の一部をわざと隠したものを用意して、それをモデルに推定させるというものです。このタスクに適合していく過程で、モデルは前後の文脈を理解するようになっていくというのが手法の意図です。英検の穴埋め問題とよく似ているので、良いBERTモデルなら英検の問題にも対応できそうだ、という予想です。 実行環境はKaggleのノートブックです。また、同じコードはGistにも上げています。 BERT solves Eiken problems (Kaggle Notebook) BERT solves Eiken problems (Gist) 事前学習済みBERTモデルの読み込み huggingface/transformers ライブラリの、fill-maskパイプラインを利用します。デフォルトで、事前学習済みのRobertaモデルがロードされます。 from transformers import pipeline model = pipeline("fill-mask") # test model( "HuggingFace is creating a {} that the community uses to solve NLP tasks.".format( model.tokenizer.mask_token)) #[{'sequence': 'HuggingFace is creating a tool that the community uses to solve NLP tasks.', # 'score': 0.17927570641040802, # 'token': 3944, # 'token_str': ' tool'}, # {'sequence': 'HuggingFace is creating a framework that the community uses to solve NLP #tasks.', # 'score': 0.11349428445100784, # 'token': 7208, # 'token_str': ' framework'}, # {'sequence': 'HuggingFace is creating a library that the community uses to solve NLP #tasks.', #... パイプラインの仕様確認をしています。 mask_tokenを含んだ文字列を与えると、その穴埋め候補を出力します。 スコアの高い順に5件出力されますが、この個数は top_kオプションで指定できます。 token_str を見ると、単語の最初に空白を含んでいることがわかります。これは、空白なしで前のトークンと接続するような接尾語と区別するためのようです。 デフォルトでは既定のボキャブラリ全体から最もスコアの高いものを選びますが、候補を targetsオプションで指定することも可能です。 英検問題のデータ化 問題を、問題文 (text)、選択肢 (choices)、正答 (answer) からなる名前付きタプルで表現します。公開されている過去問から、2021年第1回の問題を各級から10問取得しました。下記は5級の問題です。 from collections import namedtuple Problem = namedtuple("Problem", "text choices answer") eiken5 = [ Problem("A: What is your {}? B: Kazumi Suzuki.", ["hour", "club", "date", "name"], "name") ,Problem("I know Judy. She can {} French very well.", ["see", "drink", "speak", "open"], "speak") ,Problem("A: Are your baseball shoes in your room, Mike? B: No, Mom. They're in my {} at school.", ["window", "shop", "locker", "door"], "locker") ,Problem("Mysister usually plays tennis {} Saturdays.", ["by", "on", "with", "at"], "on") ,Problem("My mother likes {}. She has many pretty ones in the garden.", ["sports", "movies", "schools", "flowers"], "flowers") ,Problem("Let's begin today's class. Open your textbooks to {} 22.", ["chalk", "ground", "page", "minute"], "page") ,Problem("Today is Wednesday. Tomorrow is {}.", ["Monday", "Tuesday", "Thursday", "Friday"], "Thursday") ,Problem("I usually read magazines {} home.", ["of", "on", "with", "at"], "at") ,Problem("A: It's ten o'clock, Jimmy. {} to bed. B: All right, Mom.", ["Go", "Sleep", "Do", "Sit"], "Go") ,Problem("A: Do you live {} Tokyo? B: Yes. It's a big city.", ["after", "with", "on", "in"], "in") ] 選択肢なしで解く まずは、選択肢なしで解いてみます。ここでは、問題文をfill-maskパイプラインに当てはめて、上位5件に正答が含まれていれば正解とみなします。 import pandas as pd def solve_without_choices(problems, top_k=5): inputs = [p.text.format(model.tokenizer.mask_token) for p in problems] res = model(inputs, top_k=top_k) out = [] for p, r in zip(problems, res): # suggested answers and the scores suggested = [s["token_str"].strip() for s in r] scores = [s["score"] for s in r] suggested_scores = ",".join("%s(%.3f)" % (w,s) for w, s in zip(suggested, scores)) # location of answer if p.answer in suggested: position = suggested.index(p.answer) + 1 else: position = -1 out.append((p.text, suggested_scores, position)) out = pd.DataFrame(out, columns=["problem", "scores", "answer_position"]) out["correct"] = (out["answer_position"] > 0) return out solve_without_choices(eiken5) 結果です。選択肢を与えていないにもかかわらず、第1・2候補あたりに正答が来ています。ただし、曜日を尋ねる問題で水曜日の翌日について「Friday」が第1候補に来ているのは惜しいです(上位5件に正答の「Thursday」も入っているので正解扱いにしています)。 選択肢つきで解く 次に、選択肢を与えてその中のベストを選ぶようにします。これは、選択肢を targetsオプションに指定することで可能です。実装の中で、選択肢の単語のはじめにスペースを加えることで、前のトークンとは独立の単語として扱うことを指定しています(スペースをつけないとSuffix扱いになります)。 def solve_with_choices(problems): out = [] for p in problems: text = p.text.format(model.tokenizer.mask_token) targets = [" " + c for c in p.choices] res = model(text, targets=targets) words = [s["token_str"].strip() for s in res] scores = [s["score"] for s in res] suggested_scores = ",".join("%s(%.3f)" % (w,s) for w, s in zip(words, scores)) # location of answer if p.answer in words: position = words.index(p.answer) + 1 else: position = -1 out.append((p.text, suggested_scores, position)) out = pd.DataFrame(out, columns=["problem", "scores", "answer_position"]) out["correct"] = (out.answer_position == 1) return out solve_with_choices(eiken5) 選択肢を与えると、1つを除いて正答を1番にあげるようになりました。 やはり水曜日の翌日をこたえる問題は「Friday」が選ばれてしまい残念ながら不正解となりました。 級が低いうちは完璧ではないもの概ね正解が得られるのですが、難しくなるとだんだん選択肢の単語が既定の辞書に含まれないケースが増えてきます。そういう場合、この方式では判定できなくなってしまいます。 例えば、こちらは準1級の問題です。 p = Problem( "Some say the best way to overcome a {} is to expose oneself to what one fears. For example, people who are afraid of mice should try holding one.", ["temptation", "barricade", "phobia", "famine"], "phobia") solve_with_choices([p]) #The specified target token ` barricade` does not exist in the model vocabulary. Replacing with `Ġbarric`. #The specified target token ` phobia` does not exist in the model vocabulary. Replacing with `Ġph`. 「(  )を克服するには、実際に恐れている対象に身をさらすとよい。」という問題で、答えは "phobia (恐怖症)" なのですが、これが辞書に登録されていないため正解を得ることができません。 1級では、10問中5問で正答が辞書に含まれていませんでした。私を含めて、多くの人の辞書に eiken1 = [ Problem("Cell phones have become a permanent {} in modern society. Most perople could not imagine living without one.", ["clasp", "stint", "fixture", "rupture"], "fixture") ,Problem("Colin did not have enough money to pay for the car all at onece, so he paid it off in {} of $800 a month for two years.", ["dispositions", "installments", "enactments", "speculations"], "installments") ,Problem("When she asked her boss for a raise, Melanie's {} tone of voice made it obvious how nervous she was.", ["garish", "jovial", "pompous", "diffident"], "diffident") ,Problem("The religious sect established a {} in a rural area where its followers could live together and share everything. No private property was allowed.", ["dirge", "prelude", "repository", "commune"], "commune") ,Problem("The famous reporter was fired for {} another journalist's work. His article was almost exactly the same as that of the other journalist.", ["alleviating", "plagiarizing", "inoculating", "beleaguering"], "plagiarizing") ,Problem("Now that the local steel factory has closed down, the streets of the once-busy town are lined with {} businesses. Most owners have abandoned their stores.", ["rhetorical", "volatile", "defunct", "aspiring"], "defunct") ,Problem("The ambassador's failure to attend the ceremony held in honor of the king was considered an {} by his host nation and made already bad relations worse.", ["elucidation", "affront", "impasse", "ultimatum"], "affront") ,Problem("US border guards managed to {} the escaped prisoner as he tried to cross into Canada. He was returned to jail immediately.", ["apprehend", "pillage", "exalt", "acclimate"], "apprehend") ,Problem("Anthony enjoyed his first day at his new job. The atmosphere was {}, and his colleagues did their best to make him feel welcome.", ["congenial", "delirious", "measly", "implausible"], "congenial") ,Problem(("A: I just learned I've been {} to second violin in the school orchestra. I knew I should've practiced more." "B: Well, if you work hard, I'm sure you can get your previous position back."), ["relegated", "jeopardized", "reiterated", "stowed"], "relegated") ] solve_with_choices(eiken1) 辞書に含まれていれば正解できているのですが、知らない場合は判定不能になってしまいます。   Perplexityの比較で解く 未登録の単語に対応するために、Perplexity指標の比較によるアプローチをとります。Perplexityは「困惑」を意味する単語ですが、特に自然言語処理においてはモデルの精度指標としてよく用いられます(困惑が少ないモデルは文章をよく理解している、という解釈)。 ある文のPerplexityスコアを計算すると、そのトークンの並びが発生する確率(を変換したもの)になります(ただし、変換のため発生確率が高いほどスコアは小さい)。つまり、その文の自然さ(≒発生確率の高さ)の評価と解釈することができそうです。そこで、各選択肢を代入した文のPerplexityをそれぞれ評価して、スコアの最も低いものを選択すれば、モデルの考える最も適切な選択肢を得ることができます。 BERTモデルの中には、文のPerplexityを計算できるものがあります(参考)。これを利用して、同じ英検の問題を解きます。 なお、ここで使うのはPerplexity計算に対応しているGPT2モデルなので、上のfill-maskタスクで使用しているものとは違うモデルです。 import torch from transformers import GPT2LMHeadModel, GPT2TokenizerFast device = "cuda" model_id = "gpt2-large" model2 = GPT2LMHeadModel.from_pretrained(model_id).to(device) tokenizer = GPT2TokenizerFast.from_pretrained(model_id) def solve_with_choices2(problems): out = [] for p in problems: texts = [p.text.format(c) for c in p.choices] res = [] # store the perplexity score for each text for t in texts: tmp = tokenizer(t, return_tensors='pt') input_ids = tmp.input_ids.to(device) with torch.no_grad(): res.append(model2(input_ids, labels=input_ids)[0].item()) res = list(zip(p.choices, res)) res.sort(key=lambda a: a[1]) scores = ",".join("%s(%.3f)" % a for a in res) answer_position = [s[0] for s in res].index(p.answer) + 1 out.append((p.text, scores, answer_position)) out = pd.DataFrame(out, columns=["problem", "scores", "answer_position"]) out["correct"] = (out.answer_position==1) return out solve_with_choices(eiken1) 1級の問題の結果です。1つを除いて正しく判定することができました。 ちなみに、間違えた問題はこのようなものです。 When she asked her boss for a raise, Melanie's {} tone of voice made it obvious how nervous she was.   1. garish  2. jovial  3. pompous  4. diffident 後の文でメラニーはナーバスだった、となっていることから4番のdiffident(自信のない)が正解ですが、モデル的にはjovial(陽気な)の方が合うのではないかと判定しています。後につづく"tone of voice"とつながりが良いと見ているのかもしれません。 総合結果 最終結果(正解数)です。単純に Masked Language Modelを使って穴埋めをする方法は、ある程度まで有効なのですが、選択肢がボキャブラリーに含まれない場合に対応できないという弱点があります。これは、特に1級のように単語の難易度が高い場合に顕著でした。一方で、文のPerplexityスコアを比較することで、より「自然」なものを選ぶ方式は、級を通じて概ね機能しているように見えます(70問中69問正解)。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

FastAPI × Stripe サブスクリプション登録を試す

はじめに サブスクリプション課金を行う際、Stripeを利用すると比較的簡単に課金が行えますが、DB上にある顧客の情報を使いたかったり、すでにサブスクリプション課金が行われている顧客に再度登録が走らないようにしたいなどの理由で、バックエンド側のAPIサーバーを絡めるケースはあるかと思います。 Stripeでサブスクリプション登録をするためには、以下3つの作業が必要になります。 1. Planを登録する 2. Customerを登録する 3. Subscriptionを登録する これら3つについて、FastAPI経由での実装を試してみました。 事前準備 FastAPI FastAPI公式に従いアプリケーションの準備を行います。 Stripe Stripe公式にてアカウント登録を行います。 登録を行うと、テスト環境が利用可能となります。 また、ダッシュボードの「商品」タブから商品登録を行います。この際料金情報の登録は不要です。 Planを登録する 以下のような形で実装しました。 main.py import stripe from pydantic import BaseModel from fastapi import FastAPI, HTTPException app = FastAPI() STRIPE_SECRET_KEY = "[Stripeダッシュボードから取得できるシークレットキー]" class Plan(BaseModel): amount: int interval: str product: str nickname: Optional[str] = None @app.post("/plans/") async def create_plan(plan: Plan): stripe.api_key = STRIPE_SECRET_KEY try: result = stripe.Plan.create( amount=plan.amount, # 1単位当たりの金額 currency="jpy", # 通貨単位(今回はJPY固定) interval=plan.interval, # 支払の周期。1ヶ月毎なら"month"など product=plan.product, # 商品のID nickname=plan.nickname # プランの名前 ) except Exception: message = "StripeのPlan登録に失敗しました" return message return f'StripeのPlan登録に成功しました。ID:{result.get("id")}' ※本来は別ファイルに分けるべき箇所もありますが、お試しなので同一ファイルに記載しています これで、Bodyに以下のような内容を詰め込んだPOSTリクエストを投げることで、「月額1000円」という名前の、月に1単位あたり1000円請求するPlanが作成されます。 { "amount": 1000, "interval": "month", "product": "[商品のID]", "nickname": "月額1000円" } Customerを登録する Customerもほぼ同様ですが、ちょっとDBを絡めています。 今回はFastAPIのチュートリアルに記載のある通りにDB(SQLite)を繋いでいます。 id,email,name,stripe_customer_id,stripe_subscription_id の5カラムを持つCustomerテーブルがあることを前提としています。 main.py <前略> class StripeCustomer(BaseModel): email: str @app.post("/customers/") async def create_customer(customer: StripeCustomer, db: Session = Depends(get_db)): # DB内のCustomer検索 db_customer = crud.get_customer_by_email(db, email=customer.email) if not db_customer: raise HTTPException(status_code=400, detail="Customerが存在しません") if db_customer.stripe_customer_id: raise HTTPException(status_code=400, detail="すでにStripeにCustomerとして登録されています") # Stripeへの登録 try: stripe.api_key = STRIPE_SECRET_KEY result = stripe.Customer.create( name=db_customer.name, email=db_customer.email ) except Exception: message = "StripeのCustomer登録に失敗しました" return message # StripeのCustomerIDをDBに反映 update_customer = schemas.UpdateCustomer(id=db_customer.id,name=db_customer.name,email=db_customer.email,stripe_customer_id=result.get('id')) return crud.update_customer(db=db, customer=update_customer) これで、Bodyに以下のような内容を詰め込んだPOSTリクエストを投げることで、 ・Customerテーブルに存在する ・Stripe側に登録されていない CustomerをStripeのCustomerに登録可能となります。 { "email": "[Customerテーブルに存在するメールアドレス]" } Subscriptionを登録する Customeとほぼ同様です。 main.py <前略> class Subscription(BaseModel): customer_email: str plan: str quantity: int @app.post("/subscriptions/") async def create_subscription(subscription: schemas.Subscription, db: Session = Depends(get_db)): # DB内のCustomer検索 db_customer = crud.get_customer_by_email(db, email=subscription.customer_email) if not db_customer: raise HTTPException(status_code=400, detail="Customerが存在しません") if db_customer.stripe_subscription_id: raise HTTPException(status_code=400, detail="すでにSubscriptionが登録されています") # Stripeへの登録 stripe.api_key = STRIPE_SECRET_KEY try: result = stripe.Subscription.create( customer=db_customer.stripe_customer_id, items=[{"plan":subscription.plan, "quantity":subscription.user_count}], # 対象PlanとPlan数 collection_method="send_invoice", # 即時決済 or 請求書送付。今回は請求書送付固定 days_until_due=30 # 支払までの日数。今回は30日固定。 ) except Exception: message = "StripeのSubscription登録に失敗しました" return message # StripeのSubscription IDをDBに反映 update_customer = schemas.UpdateCustomer(id=db_customer.id, name=db_customer.name,email=db_customer.email, stripe_subscription_id=result.get('id')) return crud.update_customer(db=db, customer=update_customer) これで、Bodyに以下のような内容を詰め込んだPOSTリクエストを投げることで、 ・Customerテーブルに存在する ・Stripe側にサブスクリプションが登録されていない 場合、そのCustomerに対するサブスクリプションをStripeに登録可能となります。 { "customer_email": "[Customerのメールアドレス]", "plan": "[PlanのID]", "quantity":[数量] } が、これだけで登録すると、ちょっと問題があるので、次で問題点の対応をしていきます。 Subscription登録の修正 現在は、Subscription登録時に指定している内容は以下のとおりです。 main.py result = stripe.Subscription.create( customer=db_customer.stripe_customer_id, items=[{"plan":subscription.plan, "quantity":subscription.user_count}], collection_method="send_invoice", days_until_due=30 ) 指定項目がこれだけだと、登録日を基準として毎月の請求が発生してしまいます。 例)2021/08/02 15:30に、1ヶ月毎のPlanでSubscriptionを登録した場合 請求1回目:2021/08/02 15:30 請求2回目:2021/09/02 15:30 請求3回目:2021/10/02 15:30 ・・・ そのため、例えば毎月15日に請求を行いたい場合、15日にサブスクリプション登録を行う必要がありますが、その運用はかなり負荷が大きい(休日に作業する羽目になったり…など)です。 そこで設定できる内容に、billing_cycle_anchorがあります。 これはサブスクリプションの請求サイクル日を決め打ちできる項目で、unix時間で設定の必要がありますが、とても便利です。 ただ、今度は登録日と次の請求サイクル日の間の請求をどうするか(日割で課金するか、お金をとらないのか)の問題が発生しますが、これもproration_behaviorという項目でコントロール可能です。 create_prorationsを指定することで日割課金、noneを指定することでお金を取らず課金発生は請求1回目から、という制御ができます。 さらに、上記proration_behaviorで日割で課金する場合、Stripeが日割での金額は自動計算してくれるものの、秒単位での日割計算を行うため、登録時刻により日割請求分の金額に誤差が発生します。 この誤差を矯正するためには、backdate_startdateという項目を設定します。 これはサブスクリプション開始日時を決め打ちできる項目です。こちらもunix時間で設定します。 請求開始日固定、日割請求を行う場合は以下のような形になるかと思います。 main.py <前略> class Subscription(BaseModel): customer_email: str plan: str quantity: int billing_cycle_anchor: str backdate_start_date: str @app.post("/subscriptions/") async def create_subscription(subscription: schemas.Subscription, db: Session = Depends(get_db)): # DB内のCustomer検索 db_customer = crud.get_customer_by_email(db, email=subscription.customer_email) if not db_customer: raise HTTPException(status_code=400, detail="Customerが存在しません") if db_customer.stripe_subscription_id: raise HTTPException(status_code=400, detail="すでにSubscriptionが登録されています") # Stripeへの登録 stripe.api_key = STRIPE_SECRET_KEY try: result = stripe.Subscription.create( customer=db_customer.stripe_customer_id, items=[{"plan":subscription.plan, "quantity":subscription.user_count}], collection_method="send_invoice", days_until_due=30, billing_cycle_anchor=convert_to_unixtime(subscription.billing_cycle_anchor), proration_behavior="create_prorations", backdate_start_date=convert_to_unixtime(subscription.backdate_start_date) ) except Exception: message = "StripeのSubscription登録に失敗しました" return message # StripeのSubscription IDをDBに反映 update_customer = schemas.UpdateCustomer(id=db_customer.id, name=db_customer.name,email=db_customer.email, stripe_subscription_id=result.get('id')) return crud.update_customer(db=db, customer=update_customer) def convert_to_unixtime(dt: str): unixtime = int(datetime.strptime(dt, "%Y-%m-%d %H:%M:%S %z").timestamp()) return unixtime この場合は、Bodyに以下のような内容を詰め込んだPOSTリクエストを投げることで ・日割請求期間:2021/08/01 15:00〜2021/08/15 15:00(14日分) ・日割請求額:Plan金額 × 数量 × 14/31 ・(満額での)初回請求日:2021/08/15 15:00 ・請求2回目:2021/09/15 15:00 という形になります。 { "customer_email": "[Customerのメールアドレス]", "plan": "[PlanのID]", "quantity":[数量] "billing_cycle_anchor":"2021-08-15 15:00:00 +0900", "backdate_start_date":"2021-08-01 15:00:00 +0900" } さいごに もともと弊社(株式会社hokan)ではDjango × Stripeで決済周りは行っているのですが、せっかくなのでFastAPI × Stripeを試してみました。 どちらかというとStripeの仕様の理解が大変なように思います(日割計算が秒単位とは…!) 参考 FastAPI公式 Stripe APIリファレンス Stripe Docs 請求サイクル Stripe Docs 日割り計算 Stripe Billing 101
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

【python】テーブル定義を Backlog 記法の Markdown 形式で出力するスクリプトを作った

MySQL のテーブル定義を backlog の wiki にまとめる作業が発生したのですが、テーブル数が数百もあり1つずつ手作業でまとめていくのは大変なので、自動的に backlog 記法形式でテーブル定義を作成する python スクリプトを作成しました。 準備 ツールは GitHub にあがっていますので、 clone します。 https://github.com/kiyo27/export-table-definition また、dockerを使用しているので、開発マシンに docker がインストールされている必要があります。 テーブル定義をエクスポート docker runコマンドでdocker コンテナを起動します。コンテナ起動時にエクスポート対象のデータベース情報を環境変数で渡します。 DATABASE: データベース名 USER: ユーザー名 PASSWORD: パスワード PORT: ポート番号 docker run -e DATABASE=blog -e USER=root -e PASSWORD=root -e PORT=3306 -v ${PWD}/tsv:/tmp -v ${PWD}/script:/script mysql:5.7 sh /script/script.sh コンテナを起動すると、引数に渡したスクリプトが実行されてtsvディレクトリにテーブル定義とインデックス情報が tsv 形式で出力されます。 . ├── README.md ├── definition.md ├── script │   ├── list.txt │   └── script.sh ├── sql-md.py └── tsv ├── articles_desc.tsv ├── articles_index.tsv ├── articles_tags_desc.tsv ├── articles_tags_index.tsv ├── tags_desc.tsv ├── tags_index.tsv ├── users_desc.tsv └── users_index.tsv Backlog 記法の Markdown ファイル作成 tsv ディレクトリにファイルが出力されたら、python スクリプトを実行して Backlog 記法の Markdown ファイルを作成します。 docker run -v ${PWD}:/app python:alpine3.10 python3 /app/sql-md.py ルートディレクトリに definition.md というファイルが作成されて、Backlog 記法の Markdown で書かれたテーブル定義が出力されています。 ** articles ''テーブル定義'' |Field|Type|Null|Key|Default|Extra|h |id|int(11)|NO|PRI|NULL|auto_increment| |user_id|int(11)|NO|MUL|NULL|| |title|varchar(255)|NO||NULL|| |slug|varchar(191)|NO|UNI|NULL|| |body|text|YES||NULL|| |published|tinyint(1)|YES||0|| |created|datetime|YES||NULL|| |modified|datetime|YES||NULL|| ''インデックス'' |Table|Non_unique|Key_name|Seq_in_index|Column_name|Collation|Cardinality|Sub_part|Packed|Null|Index_type|Comment|Index_comment|h |articles|0|PRIMARY|1|id|A|0|NULL|NULL||BTREE||| |articles|0|slug|1|slug|A|0|NULL|NULL||BTREE||| |articles|1|user_key|1|user_id|A|0|NULL|NULL||BTREE||| ** articles_tags ''テーブル定義'' |Field|Type|Null|Key|Default|Extra|h |article_id|int(11)|NO|PRI|NULL|| |tag_id|int(11)|NO|PRI|NULL|| ''インデックス'' |Table|Non_unique|Key_name|Seq_in_index|Column_name|Collation|Cardinality|Sub_part|Packed|Null|Index_type|Comment|Index_comment|h |articles_tags|0|PRIMARY|1|article_id|A|0|NULL|NULL||BTREE||| |articles_tags|0|PRIMARY|2|tag_id|A|0|NULL|NULL||BTREE||| |articles_tags|1|tag_key|1|tag_id|A|0|NULL|NULL||BTREE||| ** tags ''テーブル定義'' |Field|Type|Null|Key|Default|Extra|h |id|int(11)|NO|PRI|NULL|auto_increment| |title|varchar(191)|YES|UNI|NULL|| |created|datetime|YES||NULL|| |modified|datetime|YES||NULL|| ''インデックス'' |Table|Non_unique|Key_name|Seq_in_index|Column_name|Collation|Cardinality|Sub_part|Packed|Null|Index_type|Comment|Index_comment|h |tags|0|PRIMARY|1|id|A|0|NULL|NULL||BTREE||| |tags|0|title|1|title|A|0|NULL|NULL|YES|BTREE||| ** users ''テーブル定義'' |Field|Type|Null|Key|Default|Extra|h |id|int(11)|NO|PRI|NULL|auto_increment| |email|varchar(255)|NO||NULL|| |password|varchar(255)|NO||NULL|| |created|datetime|YES||NULL|| |modified|datetime|YES||NULL|| ''インデックス'' |Table|Non_unique|Key_name|Seq_in_index|Column_name|Collation|Cardinality|Sub_part|Packed|Null|Index_type|Comment|Index_comment|h |users|0|PRIMARY|1|id|A|0|NULL|NULL||BTREE||| あとは、このファイルの中身をコピーして wiki に貼り付けるだけです。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

【Serverless Framework】ローカル環境でAPI Gateway, Lambdaを再現する

はじめに APIの開発においてAWSの「API Gateway」と「Lambda」をよく使用されていると思います。 開発する上で、毎回デプロイしてからレスポンスを確認するのは手間がかかり効率的でないということで「Serverless Framework」およびそのプラグインである「serverless offline」を使用してローカルで開発できる環境を作ります。 Serverless Framework とは サーバーレスなアプリケーションを容易に作成、管理、デプロイできるオープンソースなフレームワーク(Node.js製)。 AWSだけでなく、AzureやGCPなど様々なクラウドサービスに対応している。 serverless.ymlに各種設定を定義する。 serverless offline とは Serverless Frameworkで使えるプラグイン ローカル環境でAPI Gateway + Lambda の処理を再現してくれる 実施内容 本番環境でAPI GatewayへのリクエストをトリガーにしてLambdaで定義した関数を実行する ことを想定し、開発環境をローカルで構築する。 構築手順 1. Serverless Framework, serverless-offlineのインストール serverless frameworkをインストールする。 # インストール $ npm install serverless # 正しくインストールされているか確認 $ serverless --version => Framework Core: 2.52.1 Plugin: 5.4.3 SDK: 4.2.6 Components: 3.14.2 serverless-offline プラグインをインストールする。 $ npm install --save-dev serverless-offline 2. プロジェクトの作成 sls create -t <テンプレートの名前> -n <プロジェクト名>でプロジェクトを作成します。 serverless frameworkが提供するテンプレートはこちらを参照してください。 今回はaws-python3を使用したいと思います。 任意のディレクトリに移動したあと、以下を実行します。 $ sls create -t aws-python3 -n serverlss-sample するとディレクトリ内にhandler.pyとserverless.ymlが作成されます。 $ ls => handler.py serverless.yml 3. handler.pyを確認する handler.pyにはLambdaで実行する関数を定義します。 デフォルトでは以下のようにシンプルなhello関数が定義されてます。 handler.py import json def hello(event, context): body = { "message": "Go Serverless v1.0! Your function executed successfully!", "input": event } response = { "statusCode": 200, "body": json.dumps(body) } return response 3. serverless.ymlを編集する 使用するプラグインを記述します。 functionsのhelloにeventsを追加します。 serverless.yml service: serverless-sample # プロジェクト名 frameworkVersion: '2' provider: name: aws runtime: python3.8 lambdaHashingVersion: 20201221 functions: hello: handler: handler.hello # handler.pyの関数helloを実行する # 追加 events: - http: path: /test method: get # 追加 plugins: - serverless-offline 4. 起動 以下のコマンドでプロジェクトを起動します。 するとエンドポイントが表示されます。 $ sls offline start => ┌─────────────────────────────────────────────────────────────────────────┐ │ │ │ GET | http://localhost:3000/dev/test │ │ POST | http://localhost:3000/2015-03-31/functions/hello/invocations │ │ │ └─────────────────────────────────────────────────────────────────────────┘ レスポンスを確認する エンドポイントにアクセスするとhandler.pyに定義したbodyの内容が正しく表示されます。 $ curl http://localhost:3000/dev/test => {"message": "Go Serverless v1.0! Your function executed successfully!", ............ serverless frameworkを用いて、API Gateway, Lambdaの構成をローカル環境で再現することができました。 参考
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

挿入順を覚えている連想配列

これは何? Ruby の Hash には shift があることの紹介。 と思って調べたら、 Python3 の dict にも順序があるのがあったので、それも紹介。 Ruby の Hash Ruby の Hash は挿入順を覚えている。 挿入順を利用して古いものから順に取り出せる。 以下のような感じ。 ruby h={foo:111,bar:22,hoge:3} h[:bar]=44 h[:baz]=55 s="" while ! h.empty? do s += h.shift.inspect end p s #=> "[:foo, 111][:bar, 44][:hoge, 3][:baz, 55]" 先頭の値を見る first と、先頭の値を撤去する shift はあるけど、末尾を扱う last と pop は無い。 ruby {a:1}.first #=> [:a, 1] {a:1}.last #=> NoMethodError: undefined method `last' {a:1}.shift #=> [:a, 1] {a:1}.pop #=> NoMethodError: undefined method `pop' 値の追加がないのはわかるけど、末尾の値が扱えないのは不思議。 あと、Hash の first は引数ありなし両方あるけど、shift は引数なし。 ruby {a:1,b:2}.first(2) #=> => [[:a, 1], [:b, 2]] {a:1,b:2}.shift(2) #=> ArgumentError: wrong number of arguments (given 1, expected 0) [1,2].shift(2) #=> [1, 2] これも対応してくれればいいのにと思う。 Hash 内には挿入順が記憶されているのに、 == での比較では挿入順が度外視されるので ruby x={a:1,b:2} y={b:2,a:1} p(x==y) #=> true p([x.first, y.first]) #=> [[:a, 1], [:b, 2]] このように、等値であるオブジェクトに同じメソッドを送っても違う結果になったりする。 歴史的経緯からまあそうするしかないかとも思うけどわかりにくいよね。 Hash が挿入順を覚えていて、古い順から取り出せるので、ある種の探索アルゴリズムで役に立つ。 見つかったものを Hash に挿入しつつ、先頭から取り出して消費するという流れとか。 Python3 の dict この記事を書くにあたって調べて初めて知ったんだけど、Python3 の dict にも挿入順が入っている。 Python3 d = { "foo":111, "bar":22, "hoge":3 } d["bar"] = 44 d["baz"] = 55 s="" while d: s += repr(d.popitem()) print(s) #=> ('baz', 55)('hoge', 3)('bar', 44)('foo', 111) この popitem() は、 3.7 で「LIFO 順序が保証されるようになりました。」 とのこと。そりゃ知らなかったわけだ。 collections.OrderedDict の popItem には引数があって、末尾から取るか先頭から取るかを選べるんだけど、普通の dict の popItem には引数がない。必ず末尾から。なぜなのか。しかも ruby と逆。 辞書に入っている先頭要素は Python3 d = { "foo":111, "bar":22, "hoge":3 } iter(d.items()).__next__() #=> ('foo', 111) と、items() と iter() を使えば撤去せずに取得できるけど、もっと簡単な方法がありそうな気がする。 一方、popitem() で容易に 撤去+取得 できる末尾の要素だが、これを撤去せずに取得する方法は Python3 d = { "foo":111, "bar":22, "hoge":3 } iter(reversed(d.items())).__next__() #=> ('hoge', 3) と、むしろ面倒になる。これももっと簡単な方法がありそうな気がする。 まとめ ruby Python3 先頭を撤去せずに取得 h.first iter(d.items()).__next__() 先頭を撤去して取得 h.shift 表外下記 末尾を撤去せずに取得 不可能だと思う iter(reversed(d.items())).__next__() 末尾を撤去して取得 不可能だと思う d.popitem() Python3 の「先頭を撤去して取得」は、 firstkey = iter(d).__next__() で、先頭のキーを取得 firstval = d[firstkey] で、先頭の値を取得 del[firstkey] で、先頭の要素を撤去 という三手で実施可能。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

【Python】所得税の算出方法をプログラム化してみた

所得税の算出方法をプログラム化してみた フリーランスの税金を調べるついでにコードで遊んでいました。 ずっと勉強しているとアウトプットしたい欲にかられます。 特に使い道はないです。 法改正されてて情報が古くなる可能性はあります。 算出方法について オフィシャルを参考に 国税庁No.2260 所得税の税率 所得税の算定基礎となる所得金額=収入-必要経費-各種控除 これに金額に応じた税率が適用されて、さらに控除額が引かれたものが所得税になります。 さらに令和19年までは、ここから復興特別所得税として2.1%が追加で課税されます。 実装 まずは定番のものをインポートします。 作成後にグラフで確認するので、そのためのフォーマットも追加します。 import numpy as np import pandas as pd import matplotlib.pyplot as plt from matplotlib.ticker import ScalarFormatter, PercentFormatter 関数部分 所得から経費を差し引いた数値を入力すれば所得税が返ってきます。 他に必要な情報 p:税率 bin:税率の境界値 deduction:控除額 SpecialTax:復興特別所得税の税率 ここらへんは変更に備えて定型化したい。 def CalcTax(income): p = [0.05, 0.10, 0.20, 0.23 ,0.33 , 0.40, 0.45] bin = [i*10000 for i in [0, 195, 330, 695, 900, 1800, 4000, np.inf]] deduction = [0, 97500, 427500, 636000, 1536000, 2796000, 4976000] SpecialTax = 0.021 step = len(p) ans = 0 i = 0 while i <= step: if (income >= bin[i]) and (income < bin[i+1]): ans = (income*p[i] - deduction[i]) * (1 + SpecialTax) break i+=1 return ans 所得税のビジュアライゼーション 実際に払う所得税の金額とそれが収入に占める金額(ウォレットシェア) の二つのグラフを出してます。 体裁を整えるのにset_major_formatterを使用 税率は段階的に上がっていますが、 累進課税なので実際の負担額は滑らかに上がっていきますね。 inc = np.arange(0, 10**8, 10**5) tax = [CalcTax(i) for i in inc] fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True) ax1.plot(inc, tax) ax1.set_ylabel('tax') ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True)) ax2.plot(inc, tax/inc) ax2.set_xlabel('income') ax2.set_ylabel('tax_per_income') ax2.yaxis.set_major_formatter(PercentFormatter(1.0)) fig.tight_layout()
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

「咳カウンター」システムを作ってみた(PC側ソフトウェア編)

前回は「咳カウンター」システムの機能紹介と、記録モジュール側のハードとソフトの 内容を紹介しました。 前の記事:「咳カウンター」システムを作ってみた(機能紹介と記録モジュール編) 今回はPC側で作成した以下の2つのツールについて紹介します。 1.加速度データをグラフ化するツール 2.咳を認識して時間帯毎のヒストグラムを作成するツール 0.開発環境(2つのツールで共通です) プログラミング言語:Python3.7  開発環境はAnacondaを利用して構築し、Spyder(Pythonの統合開発環境)で開発しました。  環境構築に際しては ここ の記事を参考にしました。 データ格納フォルダとファイル名  あらかじめ記録モジュール側で保存した2つのファイルを以下のようにリネームし、  Spyderの作業ディレクトリの下に「yyyy-mm-dd」という名前のフォルダを作成して  その中へ格納しておきます。   記録モジュール側         PC側    ACL_Z000.TXT   →→→   cough_ACL_Z_yyyy-mm-dd hh-mm-ss.TXT    ACL_Z000.SDA   →→→   cough_ACL_Z_yyyy-mm-dd hh-mm-ss.SDA  ここで yyyy-mm-dd hh-mm-ss は記録モジュール側で計測した日時を表します。  (2021年8月1日9時3分5秒ならば 2021-08-01 09-03-05 ) 1.加速度データをグラフ化するツール 計測した加速度データファイルを読み込んで、指定した表示スケールと表示開始時刻からの 加速度データをグラフに表示します。 表示スケールはフルスケールで1~120秒で指定が可能で、表示開始時刻は任意の時刻を 指定することができます。上の画像は異なる表示スケール(1秒、10秒、120秒)で同じ時刻 から表示した例です。青色の線は計測した加速度データ、オレンジ色の線は加速度データの DC成分、緑色の線はそれらの差分を表しています(見易さのために青色とオレンジ色の線は -1500オフセットさせています)。 また、今回は咳の判定を目的としたものなので、加速度データは加速度値(m/sec^2)への 変換は行わず、センサーから読込んだ数値のまま使用しました。 ● ソフトウエア処理の概要  ・記録モジュール部で計測したデータの読み込みと前処理を行う(2つのツールで共通)    サンプルNo.と経過時間の対応リストの読み込み    センサデータ本体の読み込み    読み込んだセンサデータにローパスフィルタ処理    サンプルNo.と経過時間の対応リストを基に各加速度データサンプルに時刻情報を設定  ・グラフ表示スケールとグラフ表示開始時刻の入力  ・表示開始時刻から表示スケールに応じた数のデータを切り出してグラフ表示用   バッファへコピー  ・グラフ表示用バッファのデータをグラフにプロット ● プログラム できるだけコメントを入れて解り易くしたつもりですが、まだ経験が浅いのでプログラムの 記述作法やアルゴリズムに不充分なところがあると思います。ご容赦ください。 ############################################################################### # # chough_data_graphplot.py # ・記録モジュール側で保存したZ軸加速度データを読込んでグラフ表示するツール # # 処理概要 # ・記録モジュール部で計測したデータの読み込みと前処理 # ・サンプルNo.と経過時間(msec)の対応リスト(cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txt)の読み込み # ・センサデータ本体(cough_ACL_Z_yyyy-mm-dd hh-mm-ss.SDA)ファイルのデータ読み込み # ・読み込んだセンサデータにローパスフィルタ処理 # ・サンプルNo.と経過時間(msec)の対応リストを基に各加速度データサンプルに時刻情報を設定 # ・グラフ表示スケールとグラフ表示開始時刻の入力 # ・グラフ表示開始時刻からグラフ表示スケールに応じた数のデータを切り出してグラフ表示用バッファへコピー # ・グラフ表示用バッファのデータをプロット # # データ格納フォルダ(chough_data_detect_report.pyと共通の内容) # ・作業ディレクトリの下に ino_data/yyyy-mm-dd/ フォルダを作成し、その中へ以下のデータを入れる #  (作業環境に応じてプログラムを書き換えて使用してください) # ・このプログラムを実行する前に、記録モジュール側で保存したファイルを予め以下のようにリネームし、 #  上記のフォルダへ格納しておいてください # ・ACL_Z***.TXT → cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txt # ・ACL_Z***.SDA → cough_ACL_Z_yyyy-mm-dd hh-mm-ss.SDA # ・ここでyyyy-mm-dd hh-mm-ssは計測開始の日時を表す # # 入力ファイル(chough_data_detect_report.pyと共通の内容) # ・cough_ACL_Z_yyyy-mm-dd hh-mm-ss.SDA (計測データ本体。 yyyy-mm-dd 部は日付、hh-mm-ss 部は時刻) # データ形式: 2byteの整数データ列 ([下位byte,上位byte],[下位byte,上位byte],・・・・) # ・cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txt (計測データ情報。 yyyy-mm-dd 部は日付、hh-mm-ss 部は時刻) # データ形式: S_No = nnnn Timer(msec) = tttt # : : : : # # 出力ファイル # 無し # ############################################################################### import datetime import sys import numpy as np import os import matplotlib.pyplot as plt plt.rcParams['font.family'] = "MS Gothic" # グラフ中に日本語表記するための設定 ### graph_plot(): グラフの描画関数 ######################################## # # 処理概要 # ・引数 z1 ~ z3 の各加速度データをグラフ表示する # # 引数 # draw_start_time :グラフ表示開始時刻(datetime.time型) # m :表示するデータの数 # t :各加速度データサンプルに対応する時刻情報データの配列へのポインタ # z1 ~ z3 :加速度情報データ配列へのポインタ # label_name :凡例用ラベル名データの配列へのポインタ # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # draw_start_time_label :グラフ表示開始時刻(datetime.time型) # xtick_start :グラフ表示開始時刻の秒(float型) # tt[] :各加速度データサンプルに対応する表示用時刻情報データの配列 # # 戻り値 # 無し # def graph_plot(draw_start_time,m,t,z1,z2,z3,label_name): draw_start_time_label = draw_start_time.replace(microsecond=0) xtick_start = draw_start_time.second + draw_start_time.microsecond/1000000 tt = [0.0] * m i=0 for i in range(m): tt[i] = xtick_start + (t[i] - t[0]).total_seconds() plt.plot(tt, z1, label = label_name[0]) # z1線プロット plt.plot(tt, z2, label = label_name[1]) # z2線プロット plt.plot(tt, z3, label = label_name[2]) # z3線プロット # グラフ設定 plt.legend(bbox_to_anchor=(1, 1.2), loc='upper right', borderaxespad=0,fontsize=6) # 凡例の表示 plt.title(f'加速度グラフ : {draw_start_time_label} ~', fontsize=10) # グラフタイトル plt.xlabel('時刻(秒)', fontsize=8) # x軸ラベル plt.ylabel('加速度(生データ)', fontsize=8) # y軸ラベル plt.xlim([int(tt[0]), int(tt[m-1] + 1)]) # x軸範囲 plt.ylim([-2500, 1000]) # y軸範囲 plt.tick_params(labelsize = 8) # 軸ラベルの目盛りサイズ plt.yticks(np.arange(-2000, 1500, 500)) # y軸の目盛りを引く場所を指定(無ければ自動で決まる) plt.grid() # グリッドの表示 plt.show() ### info_file_read(): サンプルNo.と経過時間対応リストファイルの読み込み ########## # # 処理概要 # ・cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txtファイル からの読み込み # # 引数 # fname : 時刻補正情報データのファイル名 # co_ti : 時刻補正情報データ配列へのポインタ # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # count_time_tmp :1行分の読込みデータのバッファ # # 戻り値 # 時刻補正情報データの総数 # def info_file_read(fname, co_ti): file2 = open(fname, "r") data_end = 0 i = 0 while data_end == 0: count_time_tmp = file2.readline() if len(count_time_tmp) == 0: data_end = 1 break else: tmp_str = count_time_tmp.split(' ') co_ti[i] = [int(tmp_str[2]),int(tmp_str[5])] i += 1 file2.close() return i # 時刻補正情報データの総数を示す値を返す ### sda_file_read(): 計測データ本体のファイルからの読み込み #################### # # 処理概要 # ・cough_ACL_Z_yyyy-mm-dd hh-mm-ss.sda ファイルからの読み込み # # 引数 # fname : 計測データ本体のファイル名 # d : 読込んだセンサデータ配列へのポインタ # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # data1 :読込んだセンサデータの下位バイト # data2 :読込んだセンサデータの上位バイト # # 戻り値 # 無し # def sda_file_read(fname, d): data_end = 0 i = 0 file1 = open(fname, "rb") while data_end == 0: data1 = file1.read(1) if len(data1) == 0: data_end = 1 break else: data2 = file1.read(1) d[i] = int(data2[0]) * 256 + int(data1[0]) if int(data2[0]) >= 0x80: d[i] = d[i] - 65536 i += 1 ### lpf_process():ローパスフィルタ処理 ############################## # # 処理概要 # ・ナイキスト周波数以上をカットするローパスフィルタ処理(結果をd_lpf1[]へ格納) # ・5Hz以下のDC成分のデータ配列を求める(結果をd_lpf2_av[]へ格納) #  (DC成分抽出では、ローパスフィルタの時間遅れを補正する為に、時間軸に正方向・逆方向の両方向に #  ローパスフィルタ処理をして、両者の平均値を求めた) # ・d_lpf1[]とd_lpf2_av[]の差分配列(DC成分をカットしたセンサデータ配列)を求める # # # 引数 # d_num : 対象のセンサデータ総数 # klpf1 : ナイキスト周波数以上をカットするローパスフィルタの係数 # klpf2 : DC成分カット用ローパスフィルタの係数(z_lpf_avの計算用) # d : 対象センサデータ配列へのポインタ # d_lpf1 : 全センサデータ数分のナイキスト周波数以下でローパスフィルタ後のデータの保存メモリへのポインタ # d_lpf2_av : 全センサデータ数分のd_lpf2とd_ilpf2の平均値のデータ保存メモリへのポインタ(フィルタの遅れ補正) # d_lpf_def : d_lpf1[]とd_lpf2_av[]の差分配列(DC成分をカットしたセンサデータ配列)へのポインタ # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # d_lpf2 : 全データ数分のz軸5Hzローパスフィルタ後のデータ保存メモリ # d_ilpf2 : 全データ数分のz軸逆方向5Hzローパスフィルタ後のデータ保存メモリ # # 戻り値 # 無し # def lpf_process(d_num, klpf1, klpf2, d, d_lpf1, d_lpf2_av, d_lpf_def): d_lpf2 = [0] * d_num # 全データ数分のz軸5Hzローパスフィルタ後のデータ保存メモリ確保と初期化 d_ilpf2 = [0] * d_num # 全データ数分のz軸逆方向5Hzローパスフィルタ後のデータ保存メモリ確保と初期化 # ローパスフィルタ処理 ナイキスト周波数(c_off1:25Hz)以上をカット d_lpf1[0] = d[0] i = 1 while i < d_num: d_lpf1[i] = int((klpf1 * d[i]) + ((1-klpf1) * d_lpf1[i-1])) i += 1 # ローパスフィルタ処理 c_off2以下の成分のみを抽出 ####### d_lpf2[0] = d_lpf1[0] i = 1 while i < d_num: d_lpf2[i] = int((klpf2 * d_lpf1[i]) + ((1-klpf2) * d_lpf2[i-1])) i += 1 # 逆方向ローパスフィルタ処理 c_off2以下の成分のみを抽出 ####### d_ilpf2[d_num - 1] = d_lpf1[d_num - 1] i = d_num - 2 while 0 <= i: d_ilpf2[i] = int((klpf2 * d_lpf1[i]) + ((1-klpf2) * d_ilpf2[i+1])) i -= 1 # z_lpf2とz_ilpf2の平均値計算(フィルタによる遅れの補正) i = 0 while i < d_num: d_lpf2_av[i] = int((d_lpf2[i] + d_ilpf2[i]) / 2) i += 1 # z軸ローパスフィルタ後(c_off2以下の成分)とz軸データ(ナイキスト周波数(c_off1:25Hz)以下の成分)との差分のデータを生成 i = 0 while i < d_num: d_lpf_def[i] = d_lpf1[i] - d_lpf2_av[i] i += 1 ### set_time_info():センサデータに時刻情報を設定 ############################## # # 処理概要 # ・ナイキスト周波数以上をカットするローパスフィルタ処理(結果をd_lpf1[]へ格納) # ・DC成分として5Hz以下のデータ配列を求める(結果をd_lpf2_av[]へ格納) # ・d_lpf1[]とd_lpf2_av[]の差分配列(DC成分をカットしたセンサデータ配列)を求める # # # 引数 # stime : 計測開始時刻(datetime.datetime型) # c_time_mun : 時刻補正情報データの総数 # co_ti : 時刻補正情報データ配列へのポインタ # d_num : 対象のセンサデータ総数 # d_time : 全データ数分の時刻データ保存メモリへのポインタ # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # 特に説明が必要なもの無し # # 戻り値 # 無し # def set_time_info(stime, c_time_mun, co_ti, d_num, d_time): i = 0 while i < c_time_mun - 1: delta_t = (co_ti[i+1][1] - co_ti[i][1]) / (co_ti[i+1][0] - co_ti[i][0]) j = co_ti[i][0] j_tmp = j d_time[j] = stime + datetime.timedelta(seconds =(co_ti[i][1]/1000 - co_ti[0][1]/1000)) t_time_tmp = d_time[j] while j < co_ti[i+1][0] -1 : d_time[j+1] = t_time_tmp + datetime.timedelta(seconds =(delta_t * (j + 1 -j_tmp) /1000)) j += 1 i += 1 if co_ti[i][0] < d_num: j = co_ti[i][0] d_time[j] = stime + datetime.timedelta(seconds =(co_ti[i][1]/1000 - co_ti[0][1]/1000)) j_tmp = j j += 1 while j < d_num: d_time[j] = d_time[j_tmp] + datetime.timedelta(seconds =(delta_t * (j -j_tmp) /1000)) j += 1 ############################################################################### ### ### ### メイン処理 スタート ### ### ### ############################################################################### #### 記録モジュール部で計測したデータの読み込みと前処理 ここから >>> #################### # データファイルのフォルダとファイル名の設定 date_time = input('測定した年月日と測定開始時刻を入力してください (yyyy-mm-dd hh-mm-ss) = ') date_time_iso = date_time[:13]+':'+date_time[14:16]+':'+date_time[17:] d_f = 'ino_data/' + date_time[0:10] + '/' # 作業環境に応じて書き換えて使用してください f_sda_name = d_f + 'cough_ACL_Z_' + date_time +'.SDA' f_cough_data_info = d_f + 'cough_ACL_Z_' + date_time +'.txt' # サンプルNo.と経過時間(msec)の対応リストファイル(cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txt)の読み込み count_time = [[0,0]] * 1200 # 時間同期用リスト読み込みバッファ c_t_num = info_file_read(f_cough_data_info, count_time) # 各種設定値の定義とワークメモリの確保 sensor_sampling_rate = (count_time[c_t_num-1][0] * 1000) / (count_time[c_t_num-1][1] - count_time[0][1]) # センササンプリングレート算出(Hz) cut_f1 = sensor_sampling_rate / 2 # ローパスフィルタのカットオフ周波数(ナイキスト周波数) k_lpf1 = cut_f1 / sensor_sampling_rate # ナイキスト周波数以上をカットするローパスフィルタの係数 0 =< k_lpf < 1 cut_f2 = 5 # ローパスフィルタのカットオフ周波数(5Hz以下をDC成分としてカットするため → z_lpf_avの計算用) k_lpf2 = cut_f2 / sensor_sampling_rate # DC成分カット用ローパスフィルタの係数(z_lpf_avの計算用) total_data = int(os.path.getsize(f_sda_name) / 2) # 全データ数 z = [0] * total_data # z軸加速度データ格納メモリ確保と初期化 z_lpf1 = [0] * total_data # lpf1(ナイキスト周波数:25Hz)以上をカット後のデータを格納するメモリの確保と初期化 z_lpf2_av = [0] * total_data # z_lpf2とz_ilpf2の平均値のデータを格納するメモリの確保と初期化(フィルタの遅れ補正) z_lpf_def = [0] * total_data # z軸ローパスフィルタ後とz軸データとの差分のデータを格納するメモリの確保と初期化 t_time = [datetime.datetime] * total_data # 時刻データを格納するメモリの確保と初期化 # センサデータ本体(cough_ACL_Z_yyyy-mm-dd hh-mm-ss.SDA)ファイルのデータ読み込み sda_file_read(f_sda_name, z) # 読み込んだセンサデータにローパスフィルタ処理 lpf_process(total_data, k_lpf1, k_lpf2, z,z_lpf1, z_lpf2_av, z_lpf_def) # 各データに時刻情報を設定 stime_imput_dt = datetime.datetime.fromisoformat(date_time_iso) set_time_info(stime_imput_dt, c_t_num, count_time, total_data, t_time) stime_dt = t_time[0] etime_dt = t_time[total_data - 1] interval = etime_dt - stime_dt print(f'測定を開始した年月日 & 時刻 = {stime_dt}') print(f'測定を終了した年月日 & 時刻 = {etime_dt}') print(f'測定時間 = {interval}') #### ここまで >>> 記録モジュール部で計測したデータの読み込みと前処理 #################### ### グラフ表示スケールの設定 ここから >>> ######################################### f_key = 0 while f_key == 0: g_f_scale_str = input('グラフ横軸のフルスケール(秒)は?(1 ~ 120の数字を入力) ? = ') try: g_f_scale_tmp = int(g_f_scale_str) except ValueError: g_f_scale_tmp = 10 if 1 <= g_f_scale_tmp <= 120: f_key = 1 else: print('1 ~ 120 の数字を入力してください !!!') g_f_scale = int(g_f_scale_tmp) # g_f_scaleをint型へ変換 print(f'グラフ横軸のフルスケールは = {g_f_scale} 秒です') g_f_scale_time = datetime.timedelta(seconds=int(g_f_scale)) #### ここまで >>> グラフ表示スケールの設定 ######################################## ### グラフの表示 ここから >>> ################################################### # グラフ表示用バッファの確保 m = int(g_f_scale * sensor_sampling_rate) # グラフ描画用のバッファデータ数の設定 t_time_g = [datetime.datetime] * m # グラフ描画用の時刻データバッファの確保 z_lpf1_g = [int] * m # グラフ描画用のlpf1(ナイキスト周波数:25Hz)以上をカット後のデータ格納メモリ確保と初期化 z_lpf2_av_g = [int] * m # グラフ描画用のz_lpfとz_ilpfの平均値のデータ保存メモリ確保と初期化 z_lpf_def_g = [int] * m # グラフ描画用のz軸ローパスフィルタ後とz軸データとの差分データの保存メモリ確保と初期化 watch_time_dt = stime_dt - g_f_scale_time f_key = 1 while f_key == 1: # グラフ表示開始時刻の入力 ここから >>> #################### stime_dt_tm = stime_dt.time().replace(microsecond=0) etime_dt_tm = etime_dt.time().replace(microsecond=0) watch_time_tmp = input(f'表示開始時刻({stime_dt_tm} ~ {etime_dt_tm}) か n:次の区間/p:前の区間/e:終了 (n/p/e) を入力してください? = ') watch_time_iso = date_time[:10] + ' ' + watch_time_tmp if watch_time_tmp == 'p': watch_time_dt -= g_f_scale_time elif watch_time_tmp == 'e': f_key = 0 elif watch_time_tmp == '': watch_time_dt += g_f_scale_time else : watch_time_dt_tmp = watch_time_dt try: watch_time_dt = datetime.datetime.fromisoformat(watch_time_iso) except ValueError: print('時刻の入力が誤っています !!! HH:MM:SS のように入力してください') watch_time_dt = watch_time_dt_tmp if etime_dt <= watch_time_dt: # 表示終了位置がオーバーフローの時の処理 watch_time_dt = etime_dt - g_f_scale_time elif watch_time_dt < stime_dt: # 表示開始位置がアンダーフローの時の処理 watch_time_dt = stime_dt if f_key == 0: break # ここまで >>> グラフ表示開始時刻の入力 #################### # グラフ表示用データの切り出し ここから >>> #################### g_s_data_no_tmp = int(((watch_time_dt - stime_dt).seconds) * sensor_sampling_rate) if g_s_data_no_tmp <= 0: g_s_data_no = 0 elif total_data < g_s_data_no_tmp: g_s_data_no = total_data - int(g_f_scale_time.seconds * sensor_sampling_rate) elif t_time[g_s_data_no_tmp] < watch_time_dt : while t_time[g_s_data_no_tmp] < watch_time_dt : g_s_data_no_tmp += 1 g_s_data_no = g_s_data_no_tmp elif watch_time_dt < t_time[g_s_data_no_tmp] : while watch_time_dt < t_time[g_s_data_no_tmp] : g_s_data_no_tmp -= 1 g_s_data_no = g_s_data_no_tmp + 1 j = 0 i = 0 while i < m: # 表示対象部分のデータをバッファへコピー j = g_s_data_no + i if total_data - 1 < j: # オーバーフロー時の処理 j = total_data - 1 t_time_g[i] = t_time[j] z_lpf1_g[i] = z_lpf1[j] - 1500 # グラフの見易さのためのオフセット z_lpf2_av_g[i] = z_lpf2_av[j] - 1500 # グラフの見易さのためのオフセット z_lpf_def_g[i] = z_lpf_def[j] # i += 1 # ここまで >>> グラフ表示用データの切り出し #################### print(f'表示開始時刻 = {watch_time_dt.time()} グラフ横軸のフルスケール = {g_f_scale_time}') print(f'表示開始データのサンプル番号 = {g_s_data_no}') print(f'表示区間中の z_lpf_def の Max,Min = ({max(z_lpf_def_g[0:m - 1])},{min(z_lpf_def_g[0:m - 1])})') # グラフの描画 fig = plt.figure() draw_start_from = t_time_g[0] draw_start_time = draw_start_from.time() print(f'表示区間の開始時刻 = {draw_start_time.replace(microsecond=0)}') g_label = ['z_lpf1','z_lpf2_av','z_lpf_def'] graph_plot(draw_start_time,m,t_time_g,z_lpf1_g,z_lpf2_av_g,z_lpf_def_g,g_label) ### ここまで >>> グラフの表示 ################################################### sys.exit() 2.咳を認識して時間帯毎のヒストグラムを作成するツール 計測した加速度データファイルを読み込んで、咳波形の判定と時刻毎のヒストグラムを 作成し、PNG画像として保存します。 咳波形の判定は、振幅が閾値(現状は300に設定)を超えた波形に対して、振幅の ピーク絶対値とピーク時刻を中心に算出したFFTのF4~F16の積分値とを用いて行います。 ピーク絶対値とFFTのF4~F16積分値の組み合わせ条件は下表のように設定しました。 歩行・駆け足・階段上下などとの識別に少々苦労しましたが、咳や咳払いを正しく認識 する確率はおよそ93%、外乱を咳や咳払いとして誤認識する確率はおよそ2%程度の 精度にチューニングすることができました。咳や咳払いを見逃すより、外乱を誤認識 することを出来るだけ抑えたかったのでこのぐらいで良いかなと思っています。 ただし、この表のそれぞれの閾値は私ひとりの測定結果からチューニングしたものです ので、使用するユーザー毎にチューニングする必要があると思います。 ● ソフトウエア処理の概要  ・記録モジュール部で計測したデータの読み込みと前処理を行う(2つのツールで共通)    サンプルNo.と経過時間の対応リストの読み込み    センサデータ本体の読み込み    読み込んだセンサデータにローパスフィルタ処理    サンプルNo.と経過時間の対応リストを基に各加速度データサンプルに時刻情報を設定  ・咳のカウント処理(1次判定段階でのカウント)   ここでは振幅の閾値を超えた咳候補波形に対して波形ピーク値とFFT積分値による条件に   適合する波形を全て抽出する  ・1次判定結果に対して、ピーク時刻が近接する咳波形をひとつの咳として丸め込む   現状は0.3秒以下で近接する咳波形を1つの咳へ丸め込み、その時の咳波形のピーク値は   丸め込んだ波形の最大ピーク値、FFT積分値はその最大ピーク時刻の積分値を採用する  ・丸め込んだ咳カウント結果を基に時間毎のヒストグラムの作成と描画  ・1次判定段階でのカウント結果、近接する咳波形を丸め込んだ結果、ヒストグラムの   描画結果はそれぞれファイルへ保存 ● プログラム できるだけコメントを入れて解り易くしたつもりですが、まだ経験が浅いのでプログラムの 記述作法やアルゴリズムに不充分なところがあると思います。ご容赦ください。 ############################################################################### # # chough_data_detect_report.py # ・記録モジュール側で保存したZ軸加速度データを読込んで咳波形の判定と時刻毎のヒストグラム作成するツール # # 処理概要 # ・記録モジュール部で計測したデータの読み込みと前処理(chough_data_graphplot.pyと共通の処理) # ・サンプルNo.と経過時間(msec)の対応リスト(cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txt)の読み込み # ・センサデータ本体(cough_ACL_Z_yyyy-mm-dd hh-mm-ss.SDA)ファイルのデータ読み込み # ・読み込んだセンサデータにローパスフィルタ処理 # ・サンプルNo.と経過時間(msec)の対応リストを基に各加速度データサンプルに時刻情報を設定 # ・咳のカウント処理(1次判定段階でのカウント) #  ここでは振幅の閾値を超えた咳候補波形に対して波形ピーク値とFFT積分値による条件に適合する波形を全て抽出する # ・1次判定段階でのカウント結果に対して、近接する咳波形をひとつの咳として丸め込む # ・丸め込んだ咳カウント結果を基に時間毎のヒストグラム作成と描画 # ・1次判定段階でのカウント結果、近接する咳波形を丸め込んだ結果、ヒストグラムの描画結果はそれぞれファイルへ保存 # # データ格納フォルダ(chough_data_graphplot.pyと共通の内容) # ・作業ディレクトリの下に ino_data/yyyy-mm-dd/ フォルダを作成し、その中へ以下のデータを入れる #  (作業環境に応じてプログラムを書き換えて使用してください) # ・このプログラムを実行する前に、記録モジュール側で保存したファイルを予め以下のようにリネームし、 #  上記のフォルダへ格納しておいてください # ・ACL_Z***.TXT → cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txt # ・ACL_Z***.SDA → cough_ACL_Z_yyyy-mm-dd hh-mm-ss.SDA # ・ここでyyyy-mm-dd hh-mm-ssは計測開始の日時を表す # # 入力ファイル(chough_data_graphplot.pyと共通の内容) # ・cough_ACL_Z_yyyy-mm-dd hh-mm-ss.SDA (計測データ本体。 yyyy-mm-dd 部は日付、hh-mm-ss 部は時刻) # データ形式: 2byteの整数データ列 ([下位byte,上位byte],[下位byte,上位byte],・・・・) # ・cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txt (計測データ情報。 yyyy-mm-dd 部は日付、hh-mm-ss 部は時刻) # データ形式: S_No = nnnn Timer(msec) = tttt # : : : : # 出力ファイル # ・cough_detect_info_yyyy-mm-dd.txt (1次判定段階でのカウント結果のテキストファイル) # データ形式: Cough Count = 754 # Cough Detect Time : # No. = YYYY-MM-DD hh:mm:ss. c_width peak_time c_peak cough_fft_sum # 0 = 2021-07-20 10:33:37.729694 0.136189 10:33:37.768605 543 5993.22 # 1 = 2021-07-20 10:48:14.858228 0.077837 10:48:14.897146 338 8932.86 # : : : : : : : # # ・cough_detect_r_info_yyyy-mm-dd.txt (近接する咳波形を丸め込んだ結果のテキストファイル) # データ形式: (cough_detect_info_yyyy-mm-dd.txt に同じ) # # ・Cough_Count_Report_yyyy-mm-dd.png (ヒストグラム画像ファイル) # ############################################################################### import datetime import sys import numpy as np import os import matplotlib.pyplot as plt plt.rcParams['font.family'] = "MS Gothic" # グラフ中に日本語表記するための設定 ### info_file_read(): サンプルNo.と経過時間対応リストファイルの読み込み ########## # # 処理概要 # ・cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txtファイル からの読み込み # # 引数 # fname : 時刻補正情報データのファイル名 # co_ti : 時刻補正情報データ配列へのポインタ # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # count_time_tmp :1行分の読込みデータのバッファ # # 戻り値 # 時刻補正情報データの総数 # def info_file_read(fname, co_ti): file2 = open(fname, "r") data_end = 0 i = 0 while data_end == 0: count_time_tmp = file2.readline() if len(count_time_tmp) == 0: data_end = 1 break else: tmp_str = count_time_tmp.split(' ') co_ti[i] = [int(tmp_str[2]),int(tmp_str[5])] i += 1 file2.close() return i # 時刻補正情報データの総数を示す値を返す ### sda_file_read(): 計測データ本体のファイルからの読み込み #################### # # 処理概要 # ・cough_ACL_Z_yyyy-mm-dd hh-mm-ss.sda ファイルからの読み込み # # 引数 # fname : 計測データ本体のファイル名 # d : 読込んだセンサデータ配列へのポインタ # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # data1 :読込んだセンサデータの下位バイト # data2 :読込んだセンサデータの上位バイト # # 戻り値 # 無し # def sda_file_read(fname, d): data_end = 0 i = 0 file1 = open(fname, "rb") while data_end == 0: data1 = file1.read(1) if len(data1) == 0: data_end = 1 break else: data2 = file1.read(1) d[i] = int(data2[0]) * 256 + int(data1[0]) if int(data2[0]) >= 0x80: d[i] = d[i] - 65536 i += 1 ### lpf_process():ローパスフィルタ処理 ############################## # # 処理概要 # ・ナイキスト周波数以上をカットするローパスフィルタ処理(結果をd_lpf1[]へ格納) # ・5Hz以下のDC成分のデータ配列を求める(結果をd_lpf2_av[]へ格納) #  (DC成分抽出では、ローパスフィルタの時間遅れを補正する為に、時間軸に正方向・逆方向の両方向に #  ローパスフィルタ処理をして、両者の平均値を求めた) # ・d_lpf1[]とd_lpf2_av[]の差分配列(DC成分をカットしたセンサデータ配列)を求める # # # 引数 # d_num : 対象のセンサデータ総数 # klpf1 : ナイキスト周波数以上をカットするローパスフィルタの係数 # klpf2 : DC成分カット用ローパスフィルタの係数(z_lpf_avの計算用) # d : 対象センサデータ配列へのポインタ # d_lpf1 : 全センサデータ数分のナイキスト周波数以下でローパスフィルタ後のデータの保存メモリへのポインタ # d_lpf2_av : 全センサデータ数分のd_lpf2とd_ilpf2の平均値のデータ保存メモリへのポインタ(フィルタの遅れ補正) # d_lpf_def : d_lpf1[]とd_lpf2_av[]の差分配列(DC成分をカットしたセンサデータ配列)へのポインタ # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # d_lpf2 : 全データ数分のz軸5Hzローパスフィルタ後のデータ保存メモリ # d_ilpf2 : 全データ数分のz軸逆方向5Hzローパスフィルタ後のデータ保存メモリ # # 戻り値 # 無し # def lpf_process(d_num, klpf1, klpf2, d, d_lpf1, d_lpf2_av, d_lpf_def): d_lpf2 = [0] * d_num # 全データ数分のz軸5Hzローパスフィルタ後のデータ保存メモリ確保と初期化 d_ilpf2 = [0] * d_num # 全データ数分のz軸逆方向5Hzローパスフィルタ後のデータ保存メモリ確保と初期化 # ローパスフィルタ処理 ナイキスト周波数(c_off1:25Hz)以上をカット d_lpf1[0] = d[0] i = 1 while i < d_num: d_lpf1[i] = int((klpf1 * d[i]) + ((1-klpf1) * d_lpf1[i-1])) i += 1 # ローパスフィルタ処理 c_off2以下の成分のみを抽出 ####### d_lpf2[0] = d_lpf1[0] i = 1 while i < d_num: d_lpf2[i] = int((klpf2 * d_lpf1[i]) + ((1-klpf2) * d_lpf2[i-1])) i += 1 # 逆方向ローパスフィルタ処理 c_off2以下の成分のみを抽出 ####### d_ilpf2[d_num - 1] = d_lpf1[d_num - 1] i = d_num - 2 while 0 <= i: d_ilpf2[i] = int((klpf2 * d_lpf1[i]) + ((1-klpf2) * d_ilpf2[i+1])) i -= 1 # z_lpf2とz_ilpf2の平均値計算(フィルタによる遅れの補正) i = 0 while i < d_num: d_lpf2_av[i] = int((d_lpf2[i] + d_ilpf2[i]) / 2) i += 1 # z軸ローパスフィルタ後(c_off2以下の成分)とz軸データ(ナイキスト周波数(c_off1:25Hz)以下の成分)との差分のデータを生成 i = 0 while i < d_num: d_lpf_def[i] = d_lpf1[i] - d_lpf2_av[i] i += 1 ### set_time_info():センサデータに時刻情報を設定 ############################## # # 処理概要 # ・ナイキスト周波数以上をカットするローパスフィルタ処理(結果をd_lpf1[]へ格納) # ・5Hz以下をDC成分としてデータ配列を求める(結果をd_lpf2_av[]へ格納) # ・d_lpf1[]とd_lpf2_av[]の差分配列(DC成分をカットしたセンサデータ配列)を求める # # # 引数 # stime : 計測開始時刻(datetime.datetime型) # c_time_mun : 時刻補正情報データの総数 # co_ti : 時刻補正情報データ配列へのポインタ # d_num : 対象のセンサデータ総数 # d_time : 全データ数分の時刻データ保存メモリへのポインタ # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # 特に説明が必要なもの無し # # 戻り値 # 無し # def set_time_info(stime, c_time_mun, co_ti, d_num, d_time): i = 0 while i < c_time_mun - 1: delta_t = (co_ti[i+1][1] - co_ti[i][1]) / (co_ti[i+1][0] - co_ti[i][0]) j = co_ti[i][0] j_tmp = j d_time[j] = stime + datetime.timedelta(seconds =(co_ti[i][1]/1000 - co_ti[0][1]/1000)) t_time_tmp = d_time[j] while j < co_ti[i+1][0] -1 : d_time[j+1] = t_time_tmp + datetime.timedelta(seconds =(delta_t * (j + 1 -j_tmp) /1000)) j += 1 i += 1 if co_ti[i][0] < d_num: j = co_ti[i][0] d_time[j] = stime + datetime.timedelta(seconds =(co_ti[i][1]/1000 - co_ti[0][1]/1000)) j_tmp = j j += 1 while j < d_num: d_time[j] = d_time[j_tmp] + datetime.timedelta(seconds =(delta_t * (j -j_tmp) /1000)) j += 1 ### FFT(f):FFT処理関数 ######################################## # # 処理概要 # ・引数で渡されたデータ配列のFFTを計算して出力する # # 引数 # f : FFT処理の対象となるデータの配列(numpy.ndarray型) # (配列の要素数は2の累乗に限定する) # # 使用する外部変数 # 無し  # # 使用する主な内部変数 # プログラム中のコメントに記載 # # 戻り値 # F : FFT結果の配列(numpy.ndarray型) # def FFT(f): #f:サイズNの入力データ Nは2の累乗に限定する N = len(f) if N == 1: #Nが1のときはそのまま入力データを返す return f[0] f_even = f[0:N:2] #fの偶数番目の要素 f_odd = f[1:N:2] #fの奇数番目の要素 F_even = FFT(f_even) #(3)偶数番目の要素でFFT F_odd = FFT(f_odd) #(4)偶数番目の要素でFFT W_N = np.exp(-1j * (2 * np.pi * np.arange(0, N // 2)) / N) #tが0~N/2-1番目までのWを計算した配列 F = np.zeros(N, dtype ='complex') #FFTの出力 F[0:N//2] = F_even + W_N * F_odd #(9)を計算(t:0~N/2-1) F[N//2:N] = F_even - W_N * F_odd #(10)を計算(t:N/2~N-1) return F ############################################################################### ### ### ### メイン処理 スタート ### ### ### ############################################################################### # 記録モジュール部で計測したデータの読み込みと前処理 ここから >>> #################### # データファイルのフォルダとファイル名の設定 date_time = input('測定した年月日と測定開始時刻を入力してください (yyyy-mm-dd hh-mm-ss) = ') date_time_iso = date_time[:13]+':'+date_time[14:16]+':'+date_time[17:] d_f = 'ino_data/' + date_time[0:10] + '/' # 作業環境に応じて書き換えて使用してください f_sda_name = d_f + 'cough_ACL_Z_' + date_time +'.SDA' f_cough_data_info = d_f + 'cough_ACL_Z_' + date_time +'.txt' f_sda_cough_detect = d_f + 'cough_detect_info_' + date_time[0:10] +'.txt' f_sda_cough_detect_r = d_f + 'cough_detect_r_info_' + date_time[0:10] +'.txt' d_f_report = d_f # サンプルNo.と経過時間(msec)の対応リストファイル(cough_ACL_Z_yyyy-mm-dd hh-mm-ss.txt)の読み込み count_time = [[0,0]] * 1200 # 時間同期用リスト読み込みバッファ c_t_num = info_file_read(f_cough_data_info, count_time) # 各種設定値の定義とワークメモリの確保 sensor_sampling_rate = (count_time[c_t_num-1][0] * 1000) / (count_time[c_t_num-1][1] - count_time[0][1]) # センササンプリングレート算出(Hz) cut_f1 = sensor_sampling_rate / 2 # ローパスフィルタのカットオフ周波数(ナイキスト周波数) k_lpf1 = cut_f1 / sensor_sampling_rate # ローパスフィルタの係数 0 =< k_lpf < 1 cut_f2 = 5 # ローパスフィルタのカットオフ周波数(5Hz以下をDC成分としてカットするため → z_lpf_avの計算用) k_lpf2 = cut_f2 / sensor_sampling_rate # ローパスフィルタの係数(z_lpf_avの計算用) total_data = int(os.path.getsize(f_sda_name) / 2) # 全データ数 z = [0] * total_data # z軸加速度データ格納メモリ確保と初期化 z_lpf1 = [0] * total_data # lpf1(ナイキスト周波数:25Hz)以上をカット後のデータを格納するメモリの確保と初期化 z_lpf2_av = [0] * total_data # z_lpf2とz_ilpf2の平均値のデータを格納するメモリの確保と初期化(フィルタの遅れ補正) z_lpf_def = [0] * total_data # z軸ローパスフィルタ後とz軸データとの差分のデータを格納するメモリの確保と初期化 t_time = [datetime.datetime] * total_data # 時刻データを格納するメモリの確保と初期化 # センサデータ本体(cough_ACL_Z_yyyy-mm-dd hh-mm-ss.SDA)ファイルのデータ読み込み sda_file_read(f_sda_name, z) # 読み込んだセンサデータにローパスフィルタ処理 lpf_process(total_data, k_lpf1, k_lpf2, z,z_lpf1, z_lpf2_av, z_lpf_def) # 各データに時刻情報を設定 stime_imput_dt = datetime.datetime.fromisoformat(date_time_iso) set_time_info(stime_imput_dt, c_t_num, count_time, total_data, t_time) stime_dt = t_time[0] etime_dt = t_time[total_data - 1] interval = etime_dt - stime_dt print(f'測定を開始した年月日 & 時刻 = {stime_dt}') print(f'測定を終了した年月日 & 時刻 = {etime_dt}') print(f'測定時間 = {interval}') # ここまで >>> 記録モジュール部で計測したデータの読み込みと前処理 #################### # 咳カウント処理用バッファ確保 ###################################### cough_count_max = 1000 # 咳カウントデータ数(1次判定段階でのカウント数)の最大値(バッファ確保用に定義) cough_time = [datetime.datetime] * cough_count_max # 咳波形の始点位置(閾値を超えた時刻)を格納するバッファを確保 cough_time_e = [datetime.datetime] # 咳波形がゼロレベルに戻った時点の時刻データの一時保存メモリ cough_time_p = [datetime.datetime] * cough_count_max # 咳波形のピーク時刻を格納するバッファを確保 cough_width = [datetime.timedelta] * cough_count_max # 咳波形の波形幅データを格納するバッファを確保 cough_peak = [0] * cough_count_max # 咳波形のピーク値を格納するバッファを確保 cough_peak_abs = [0] * cough_count_max # 咳波形のピーク値を格納するバッファを確保 cough_z_fft_sum = [0] * cough_count_max # 咳波形候補のf4(約6.3Hz)~f16(25Hz)の積分値を格納するバッファを確保 c_p_max_i = [0] * cough_count_max # 咳波形のローカルピーク位置を格納するバッファを確保 cough_time_r = [datetime.datetime] * cough_count_max # 咳波形が閾値を超えた時刻を格納するバッファを確保(咳波形を丸め込んだ後のデータ用) width_r = [0.0] * cough_count_max # 咳波形の波形幅データを格納するバッファを確保(咳波形を丸め込んだ後のデータ用) fft_window = 32 # FFTのウインドウサイズ(サンプリング周波数50Hzで約0.64秒) z_np = np.arange(fft_window) # FFT計算用z軸加速度元データを格納するバッファ z_fft = np.arange(fft_window) # FFT計算結果保存用バッファ # 咳のカウント処理(1次判定段階でのカウント) ここから >>> ############################ # 咳かどうかの判定にはz_lpf_def[]データを使用する # ここでは振幅の閾値を超えた咳候補波形に対して波形ピーク値とFFT積分値による条件に適合する波形を全て抽出する # (後の処理で近接する咳波形をひとつの咳として丸め込む) # 現状は振幅の閾値(threshold)を超えた波形に対して以下の条件で判定する # ・波形ピーク値がchough_peak_th1以上chough_peak_th2未満かつFFT積分値がcough_z_fft_sum_th1以上ならば咳 # ・波形ピーク値がchough_peak_th2以上かつFFT積分値がcough_z_fft_sum_th2以上ならば咳 #  と判定する # chough_peak_th1,2、cough_z_fft_sum_th1,2 の値は私一人の測定結果からチューニングした値を設定したものであり、 # 個人差がある可能性がありますので、使用する人に合わせて調整してください。 # cough_count = 0 # 咳としてカウントした数(1次判定段階でのカウント数) cc_up = 0 # 波形が閾値を超えたかどうかのフラグ:超えている時は +1(up方向) or -1(down方向)、超えていない時は0 threshold = 300 # 咳の候補として判定するための閾値 zerolevel = 50 # 波形のゼロレベルを定義(波形の開始/終了位置を判定するために使用) chough_rounding = 0.3 # 近接する咳波形をひとつの咳として丸め込むための条件(咳波形の丸め込みで使用) chough_peak_th1 = 300 # 咳としてカウントするかどうかの判定条件1(波形ピーク値条件1) chough_peak_th2 = 400 # 咳としてカウントするかどうかの判定条件2(波形ピーク値条件2) cough_z_fft_sum_th1 = 4500 # 咳としてカウントするかどうかの判定条件1(FFT積分値条件1) cough_z_fft_sum_th2 = 4000 # 咳としてカウントするかどうかの判定条件2(FFT積分値条件2) i = 0 while i < total_data: if cc_up == 0: # 対象波形の値が閾値を超えた事の判定処理 if threshold < z_lpf_def[i]: # up側の閾値を超えた時 i_back = 0 c_s_i = i while zerolevel < z_lpf_def[i-i_back-1] : # 波形がゼロレベルを超えた時点(開始位置)まで遡る i_back += 1 cough_time[cough_count] = t_time[i-i_back] # 始点位置(閾値を超えた時刻)を保存 cc_up = 1 # 波形がup方向に閾値を超えたことを示すフラグをセット elif z_lpf_def[i] < -threshold: # down側の閾値を超えた時 i_back = 0 c_s_i = i while z_lpf_def[i-i_back-1] < -zerolevel: # 波形がゼロレベルを超えた時点(開始位置)まで遡る i_back += 1 cough_time[cough_count] = t_time[i-i_back] # 始点位置(閾値を超えた時刻)を保存 cc_up = -1 # 波形がdown方向に閾値を超えたことを示すフラグをセット i += 1 elif cc_up == 1: # 対象波形の値がup方向に閾値を超えた後の処理 if z_lpf_def[i] < zerolevel: # 波形がゼロレベルに戻った時(終了位置) c_e_i = i cough_time_e = t_time[i] cough_width[cough_count] = cough_time_e - cough_time[cough_count] # 咳候補波形の波形幅を保存 cough_peak[cough_count] = max(z_lpf_def[c_s_i:c_e_i]) # 咳候補波形のピーク値を保存 c_p_i = c_s_i + z_lpf_def[c_s_i:c_e_i].index(cough_peak[cough_count]) cough_time_p[cough_count] = t_time[c_p_i] # 咳候補波形のピーク時刻を保存 # 咳候補波形のFFT解析 c_m_i = int((c_e_i + c_s_i)/2) c_fft_s_i = int(c_m_i - fft_window/2) for iii in range(fft_window): # 咳候補波形のピーク時刻を中心にfft_window分の加速度データを切り出す z_np[iii] = z_lpf_def[c_fft_s_i + iii] z_fft = FFT(z_np) z_fft_sum = sum(np.abs(z_fft[int(4 * fft_window / 32):(int(fft_window / 2) + 1)])) # 咳に特有の周波数成分を積分 cough_z_fft_sum[cough_count] = z_fft_sum # 咳に特有の周波数成分の積分値を保存 # 咳としてカウントするかどうかの判定(咳候補波形に対して波形ピーク値とFFT積分値による条件で判定) if chough_peak_th2 <= cough_peak[cough_count]: # chough_peak_th2 = 400 if cough_z_fft_sum_th2 <= cough_z_fft_sum[cough_count]: # cough_z_fft_sum_th2 = 4000 cough_count += 1 elif chough_peak_th1 <= cough_peak[cough_count]: # chough_peak_th1 = 300 if cough_z_fft_sum_th1 <= cough_z_fft_sum[cough_count]: # cough_z_fft_sum_th1 = 4500 cough_count += 1 cc_up = 0 i += 1 else: i += 1 elif cc_up == -1: # 対象波形の値がdown方向に閾値を超えた後の処理 if -zerolevel < z_lpf_def[i]: # 波形がゼロレベルに戻った時(終了位置) c_e_i = i cough_time_e = t_time[i] cough_width[cough_count] = cough_time_e - cough_time[cough_count] # 咳候補波形の波形幅を保存 cough_peak[cough_count] = min(z_lpf_def[c_s_i:c_e_i]) # 咳候補波形のピーク値を保存 c_p_i = c_s_i + z_lpf_def[c_s_i:c_e_i].index(cough_peak[cough_count]) cough_time_p[cough_count] = t_time[c_p_i] # 咳候補波形のピーク時刻を保存 # 咳候補波形のFFT解析 c_m_i = int((c_e_i + c_s_i)/2) c_fft_s_i = int(c_m_i - fft_window/2) for iii in range(fft_window): # 咳候補波形のピーク時刻を中心にfft_window分の加速度データを切り出す z_np[iii] = z_lpf_def[c_fft_s_i + iii] z_fft = FFT(z_np) z_fft_sum = sum(np.abs(z_fft[int(4 * fft_window / 32):(int(fft_window / 2) + 1)])) # 咳に特有の周波数成分を積分 cough_z_fft_sum[cough_count] = z_fft_sum # 咳に特有の周波数成分の積分値を保存 # 咳としてカウントするかどうかの判定(咳候補波形に対して波形ピーク値とFFT積分値による条件で判定) if cough_peak[cough_count] <= -chough_peak_th2: # chough_peak_th2 = 400 if cough_z_fft_sum_th2 <= cough_z_fft_sum[cough_count]: # cough_z_fft_sum_th2 = 4000 cough_count += 1 elif cough_peak[cough_count] <= -chough_peak_th1: # chough_peak_th1 = 300 if cough_z_fft_sum_th1 <= cough_z_fft_sum[cough_count]: # cough_z_fft_sum_th1 = 4500 cough_count += 1 cc_up = 0 i += 1 else: i += 1 # 咳カウント結果の表示とファイルへの保存(1次判定段階でのカウント結果) print(f'Cough Count = {cough_count}') print('Cough Detect Time :\n No. = YYYY-MM-DD hh:mm:ss. c_width peak_time c_peak cough_fft_sum') for i in range(cough_count): width = float(cough_width[i].seconds) + float(cough_width[i].microseconds)/1000000 print(f' {i: >3} = {str(cough_time[i])} {width:06f} {cough_time_p[i].time()} {cough_peak[i]:>5} {cough_z_fft_sum[i]:8.02f}') file3 = open(f_sda_cough_detect, 'w') file3.write(f'Cough Count = {cough_count}\n') file3.write('Cough Detect Time :\n No. = YYYY-MM-DD hh:mm:ss. c_width peak_time c_peak cough_fft_sum\n') for i in range(cough_count): width = float(cough_width[i].seconds) + float(cough_width[i].microseconds)/1000000 file3.write(f' {i: >3} = {str(cough_time[i])} {width:06f} {cough_time_p[i].time()} {cough_peak[i]:>5} {cough_z_fft_sum[i]:8.02f}\n') file3.close() # ここまで >>> 咳のカウント処理(1次判定段階でのカウント) ############################ # 1次判定段階でのカウント結果に対して、近接する咳波形をひとつの咳として丸め込む ここから >>> ########### # ここでは chough_rounding 以下の時間でピーク時刻が近接する咳波形を1つの咳へ丸め込む # chough_rounding は現状は0.3秒に設定しています # 咳波形のピーク値は丸め込んだ波形の最大値とし、FFT積分値はそのピーク時刻を中心としたFFT積分値を採用する for i in range(cough_count): cough_peak_abs[i] = abs(cough_peak[i]) j = 0 c_t_p_i_tmp1 = 0 for i in range(cough_count-1): cough_time_p_def = cough_time_p[i+1] - cough_time_p[i] cough_time_p_def_tmp = float(cough_time_p_def.seconds) + float(cough_time_p_def.microseconds)/1000000 if (chough_rounding < cough_time_p_def_tmp) | (i == cough_count - 2): # chough_rounding秒以下かどうかの判定 if i == cough_count - 2: c_t_p_i_tmp2 = i + 1 else : c_t_p_i_tmp2 = i if c_t_p_i_tmp1 != c_t_p_i_tmp2: cough_peak_abs_max = max(cough_peak_abs[c_t_p_i_tmp1:c_t_p_i_tmp2]) c_p_max_i[j] = c_t_p_i_tmp1 + cough_peak_abs[c_t_p_i_tmp1:c_t_p_i_tmp2].index(cough_peak_abs_max) width_r_tmp = cough_time[c_t_p_i_tmp2] - cough_time[c_t_p_i_tmp1] + cough_width[c_t_p_i_tmp2] width_r[j] = float(width_r_tmp.seconds) + float(width_r_tmp.microseconds)/1000000 cough_time_r[j] = cough_time[c_t_p_i_tmp1] else : c_p_max_i[j] = c_t_p_i_tmp2 width_r_tmp = cough_width[c_t_p_i_tmp2] width_r[j] = float(width_r_tmp.seconds) + float(width_r_tmp.microseconds)/1000000 cough_time_r[j] = cough_time[c_t_p_i_tmp1] c_t_p_i_tmp1 = i + 1 j += 1 # 近接する咳波形をひとつの咳として丸め込んだ咳カウント結果をファイルへ保存 c_p_max_i_count = j print(f'Cough Count = {c_p_max_i_count}') print('Cough Detect Time :\n No. = YYYY-MM-DD hh:mm:ss. c_width peak_time c_peak cough_fft_sum') for j in range(c_p_max_i_count): i = c_p_max_i[j] print(f' {j: >3} = {str(cough_time_r[j])} {width_r[j]:06f} {cough_time_p[i].time()} {cough_peak[i]:>5} {cough_z_fft_sum[i]:8.02f}') file4 = open(f_sda_cough_detect_r, 'w') file4.write(f'Cough Count = {c_p_max_i_count}\n') file4.write('Cough Detect Time :\n No. = YYYY-MM-DD hh:mm:ss. c_width peak_time c_peak cough_fft_sum\n') for j in range(c_p_max_i_count): i = c_p_max_i[j] file4.write(f' {j: >3} = {str(cough_time_r[j])} {width_r[j]:06f} {cough_time_p[i].time()} {cough_peak[i]:>5} {cough_z_fft_sum[i]:8.02f}\n') file4.close() # ここまで >>> 1次判定段階でのカウント結果に対して、近接する咳波形をひとつの咳として丸め込む ########### # 近接する咳波形をひとつの咳として丸め込んだ咳カウント結果を基に時間毎のヒストグラム作成と描画 cough_count_per_hour = [0] * 24 jikoku = range(0, 24) i = 0 for i in range(c_p_max_i_count): cough_hour = int(cough_time_r[i].timetuple().tm_hour) cough_count_per_hour[cough_hour] += 1 print(f' Hour = {cough_count_per_hour}') # ヒストグラムの描画 x = jikoku y = cough_count_per_hour count_max = max(y) title = '咳カウントレポート ' + str(stime_dt.date()) f_name_png = 'Cough_Count_Report_' + str(stime_dt.date()) x_label = '時刻' y_label = 'カウント' x_min = 0 x_max = 23 y_min = 0 y_max = 35 if y_max < count_max: y_max = count_max + 5 plt.bar(x, y, 0.7) # 棒グラフプロット plt.title(title, fontsize=15) # グラフタイトル plt.xlabel(x_label, fontsize=10) # x軸ラベル plt.ylabel(y_label, fontsize=10) # y軸ラベル #plt.ylim([y_min, y_max]) # y軸範囲 plt.tick_params(labelsize = 8) # 軸ラベルの目盛りサイズ plt.xticks(np.arange(x[0], x[23], 2)) # x軸の目盛りを引く場所を指定(無ければ自動で決まる) plt.yticks(np.arange(y_min, y_max, 5)) # y軸の目盛りを引く場所を指定(無ければ自動で決まる) plt.grid() # グリッドの表示 plt.savefig(d_f_report + f_name_png + '.png', format="png", dpi=300) plt.show() sys.exit() 最後に 「咳カウンター」システムを作ってみた(機能紹介と記録モジュール編) 「咳カウンター」システムを作ってみた(PC側ソフトウェア編) の2回に渡ってシステムを紹介しました。 当初の目標はほぼクリアできたと思っていますが、実際に使ってみると喉元に貼り付ける センサ部が汗などで剝がれてしまうことがあるなど、いくつかの問題も残っていますので、 引き続きレベルアップしていきたいと思います。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

「咳カウンター」システムを作ってみた(機能紹介と記録モジュール編)

1.はじめに 喘息持ちで主治医から咳や咳払いの状況を毎日記録するように言われているので、 趣味と実益を兼ねて「咳カウンター」とでも言うシステムを作ってみました。 ここでは自分の製作記録として2つの記事に分けて紹介します。 2.システム全体の構成と機能概要 システム全体は喉元に貼り付ける加速度センサ部、携帯型の記録モジュール部、そして PC側の処理とで構成されています。 記録モジュールは身に着けて使うのでできるだけ薄く小さくする為にミンティアケースに 入れる事を目標に作りました。それを100円ショップのIDカード入れに入れて首から ぶら下げて、喉元に貼り付けた加速度センサ部からのZ軸加速度データ(喉に垂直方向) をSDカードへ記録します。 PC側では測定した加速度データをグラフ化するツールと、咳を認識して時間帯毎の ヒストグラムを作成するツールをPythonにて作成しました。咳や咳払いと歩行などの 外乱振動との識別に少々苦労しましたが、自分ひとりのデータですが、咳や咳払いを 正しく認識する確率はおよそ93%、外乱振動を咳や咳払いとして誤認識する確率は およそ2%程度の精度にチューニングすることができました。咳や咳払いを見逃すより、 外乱を誤認識することを出来るだけ抑えたかったのでこのぐらいで良いかなと思っています。 3.記録モジュール 3-1.ハード構成 加速度センサ部はLIS3DH搭載の3軸加速度センサモジュールを使用し、記録モジュール部 との接続にはUSBケーブル(シールド付き)を使用しました。センサ部は喉元に貼り付けて 使うので防湿のために全体をレジンで固めました。 記録モジュールは、Seeeduino XIAO、microSDカードモジュール、リチウムイオン電池、 充電モジュールをミンティアケースに詰め込みました。 今回の製作で使用した部品の情報は以下の通りです。 1)3軸加速度センサモジュール   LIS3DH搭載 三軸加速度センサモジュール(スイッチサイエンス)   LIS3DH データシート   *喉元に貼り付けて使うので防湿のためにセンサ部全体をレジンで固めました。 2)Seeeduino XIAO(Arduino互換マイクロコントローラ)   Seeeduino XIAO(スイッチサイエンス)   Seeeduino XIAO データシート 3)SDカードリーダーライタ―   Aideepen マイクロSDストレージ拡張ボードマイクロSDカードメモリシールドモジュール   *PSI接続できるものであれば他の物でも代用できます。 4)充電モジュール   小型リチウムイオン電池充電器 USB Type -Cコネクタ搭載(スイッチサイエンス)   *ミンティアケースに入れるために電池接続側コネクタを外して使いました。 5)リチウムポリマー電池   502030-250mAhリチウムポリマー電池   *ケースに収まるものであれば他の物でも代用できます。    ただし、20時間以上計測するには250mAh程度以上は必要です。 回路図(といっても各モジュールを繋いだだけですが、、、) 記録モジュール部実配線例(ミンティアケースに入れた例です) 3-2.ソフトウェア 開発環境:Arduino IDE  (Arduino IDEのインストールとSeeeduinoのボードマネージャの追加方法は → ここ ) ソフトウェア処理の概要 ・電源SWがON後、LEDが10秒間「明」→3回「暗・明」→「暗」で計測が開始されます   → 計測開始の日時(例えば2021年7月15日の10時12分14秒)を別途メモしておきます    (後で咳の認識やグラフ表示をするデータのファイル名の一部として使用します) ・計測開始後は加速度センサ(LIS3DH)のデータを読込んでSDカードへ保存し続けます ・加速度センサのサンプリング周波数は50Hz、12bit精度、LIS3DHのFIFOに16サンプル  以上溜まったらデータをバッファメモリへ読込みます ・バッファのセンサデータが128サンプルを超えたらバイナリデータ(2byte/サンプル)  としてACL_ZXXX.SDAファイルへ追記保存します ・センサデータ保存の32回毎に累積サンプル数とプログラム起動時からの経過時間を  ACL_ZXXX.TXTファイルへ追記保存します(後の処理でサンプリング周波数のばらつき  を補正する為に使用します・・・不要かもしれませんが、、、)  また、この時のファイルへのデータ書き込み中だけLEDを「明」としており、この  LEDの点灯(約82秒毎)で正常に動作していることの確認ができます。 ・上記のACL_ZXXXの「XXX」の部分はSDカード内のNEXT_SDA.TXTファイル中に記されて  いる値となります。(詳細はソフトウェアスケッチ中のコメントに記載) ・加速度センサやSDカードへのアクセスエラーが発生した場合は、正しく計測されて  いないので、その場合はLEDが連続点滅(3回の「明・暗」と1回の「暗・暗」)して  知らせます ・電源SWのOFFまたはSDカードを抜くことで計測を終了します ソフトウェアスケッチ できるだけコメントを入れて解り易くしたつもりですが、まだ経験が浅いのでプログラムの 記述作法やアルゴリズムに不充分なところがあると思います。ご容赦ください。 /* * Cough_data_logger.ino * * 概要 * ・加速度センサのデータを読み込んでSDカードへ保存する * ・使用する加速度センサはLIS3DH * サンプリング周波数:50Hz、センサ値:12bit精度、 * LIS3DHのFIFOに16サンプル以上溜まったらデータを読み込む * ・センサデータが128サンプルを超えたらバイナリデータ(2byte/サンプル)として.SDAファイルへ保存する * ・センサデータ保存の32回毎に累積サンプル数とプログラム起動時からの経過時間を.TXTファイルへ保存する * (後の処理でサンプリング周波数のばらつきを補正する為に使用する) * ・加速度センサやSDカードへのアクセスでエラーが発生した場合は、正しく計測されていません *  その場合は、LEDが連続点滅(3回の「明・暗」と1回の「暗・暗」)して知らせます * ・電源SWのON後、LED点滅(10秒間「明」→ 3回の「暗・明」→ 「暗」)のタイミングが計測開始となる * → 計測開始の日時(2021年7月15日の10時12分14秒ならば「2021-07-15 10-12-14」)を別途メモしておく * → 後で咳の認識やグラフ表示の為に使用します * ・電源SWのOFFまたはSDカードを抜くことで計測が終了となる * * ・使用ファイル * NEXT_SDA.TXT:次に計測するときの.SDAデータと.TXTデータの保存ファイル名を保持 * 計測する際には事前にSDカードにこのファイルを作成しておく必要あり * データ形式 ACL_Z000.XXX ← 初期状態では最低限この行は必須 * ACL_Z001.XXX ← 計測する毎に行が増えていく * : * : * * ACL_Z***.SDA(センサデータ本体、SDA = Sensor Data Aruduinoの頭文字) * データ形式 2byteの整数データ列([下位byte][上位byte]の連続) * * ACL_Z***.TXT(サンプルNo.と経過時間(msec)の対応リスト、およそ4096サンプル毎に記録) * データ形式 S_No = 0 Timer(msec) = 13058 ← 最初の行は測定開始時の時間を表す * S_No = 4096 Timer(msec) = 92695 ← 4096番目のサンプルの時間は92695msec * : : : : * : : : : * */ #include <Wire.h> #include <SPI.h> #include <SD.h> // 定数の定義(ここから)************************************************** #define DEBUG 0 // デバッグ用シリアルプリント出力のオン/オフ 0:OFF 1:ON #define LED_ON_TIME 100 // ダイアグ表示時のLEDの「明」時間(msec) #define LED_OFF_TIME 900 // ダイアグ表示時のLEDの「暗」時間(msec) #define ACC_Z_BUF_SIZE 128 // Acc_z用バッファサイズの最小値(int型128個 = 256bytes) // LIS3DH の設定(ここから)********** #define ACL_SEN_ADRS 0x18 // 加速度センサのI2Cアドレスの設定 #define CTRL_REG1 0x20 // CTRL_REG1のレジスタアドレス設定 #define CTRL_REG1_DEF_VALUE 0x07 // CTRL_REG1のデフォルト値 #define CTRL_REG1_VALUE 0x47 // CTRL_REG1の設定値(50Hz、XYZ軸全て有効) #define CTRL_REG4 0x23 // CTRL_REG4のレジスタアドレス設定 #define CTRL_REG4_VALUE 0x08 // CTRL_REG4の設定値(+/- 2g、高解像度有効) #define CTRL_REG5 0x24 // CTRL_REG5のレジスタアドレス設定 #define CTRL_REG5_FIFO_EN 0x40 // CTRL_REG5の設定値(FIFOを有効にする) #define CTRL_REG5_FIFO_DEN 0x00 // CTRL_REG5の設定値(FIFOを無効にする) #define FIFO_CTRL_REG 0x2E // FIFO_CTRL_REGのレジスタアドレス設定 #define FIFO_CTRL_REG_BYPASS 0x0F // FIFO_CTRL_REGの設定値(Bypass mode、ウォーターマークレベル = 15) #define FIFO_CTRL_REG_FIFO 0x4F // FIFO_CTRL_REGの設定値(FIFO mode、ウォーターマークレベル = 15) #define FIFO_SRC_REG 0x2F // FIFO_SRC_REGのレジスタアドレス設定 #define WTM 0x80 // FIFO_SRC_REG内のWTM bit の位置(bit[7]) #define FSS 0x1F // FIFO_SRC_REG内のFSS bit の位置(bit[4:0]) #define OUT_X_L 0x28 // X軸加速度データ下位8bit読込みレジスタアドレス #define OUT_X_L_B_READ 0xA8 // X軸加速度データBURST読込みレジスタアドレス #define OUT_Y_L 0x2A // Y軸加速度データ下位8bit読込みレジスタアドレス #define OUT_Y_L_B_READ 0xAA // Y軸加速度データBURST読込みレジスタアドレス #define OUT_Z_L 0x2C // Z軸加速度データ下位8bit読込みレジスタアドレス #define OUT_Z_L_B_READ 0xAC // Z軸加速度データBURST読込みレジスタアドレス #define OUT_Z_H 0x2D // Z軸加速度データ上位8bit読込みレジスタアドレス // ********** LIS3DH の設定(ここまで) // ************************************************** 定数の定義(ここまで) // データ収集用外部変数定義(ここから) **************************************** const int Psi_chipSelect = 7; // PSI(SDリーダーライターとのインターフェース)のチップセレクト。Seeduino XIAO の場合 byte Status_reg, Acl_sen_zl, Acl_sen_zh; char Next_f_name[15] = "ACL_Z000.XXX"; // 次に作成する計測値データ保存ファイル名 char F_no[5] = "000"; // 計測値データ保存ファイル名の番号部の生成用バッファ char F_name[15] = "XXXXXXXX.XXX"; // 計測値データを保存するファイル名 int16_t Acl_z[ACC_Z_BUF_SIZE + 32]; // Z軸加速度データ読込み用バッファ // +32は他の処理で読込み遅れが生じた時に備え、FIFOバッファ分の余裕を見込むため // **************************************** データ収集用外部変数定義(ここまで) // ダイアグ用外部変数(ここから) **************************************** byte Err_code = 0; // ダイアグ用エラーコード格納メモリ確保と0で初期化、0:エラー無し、1:エラー有り byte Err_code_pattern[4][2] = {{1,0},{1,0},{1,0},{0,0}}; // ダイアグのエラー発生時のLED出力パターン byte Tflag; // LED出力区間切替りタイミングイベント発生フラグ(1:発生、0:途中) byte Led_out_period = 0x00; // 現在の出力区間 bit3:区間の有無(0:無区間 1:区間有無区間とはプログラム起動直後を意味する)、 // bit0:区間(1:100msec区間 0:900msec区間) int16_t Led_out_count; // エラーコードパターンの現在の出力位置 uint32_t Led_on_off_timing; // LED出力区間切換えタイマー用時間情報 // **************************************** ダイアグ用外部変数(ここまで) /* diag_led_out():ダイアグLED出力処理 * 処理概要 * エラー発生時のLED出力は(100msec+900msec)を周期とし、 * 100msec区間に"明"を900msec区間に"暗"を設定したパターンを3回、 * 100msecと900msecの両方の区間に"暗"を設定したパターンを1回出力する。 * Tflag=1の場合(LED出力区間切換えタイミングイベントが発生)に以下の処理をする。 * 何らかのエラーが発生している場合  * 次が100msec区間の時 → Err_code_pattern[Led_out_count][0]が示す位置の値に応じてLED出力開始する * 次が900msec区間の時 → Err_code_pattern[Led_out_count][1]が示す位置の値に応じてLED出力開始する * Led_out_countは0~3の間でカウントアップ * 最後にTflagをクリアする * * 使用する外部変数 * Tflag (R/W) :LED出力区間切換えタイミングイベント発生フラグ(1:発生、0:途中)  * diag_timing()で 1 にセットし、ここで 0 にする * Err_code (R) :ダイアグ用エラーコード格納メモリ確保と0で初期化、0:エラー無し、1:エラー有り * Err_code_pattern (R) :ダイアグのエラー発生時のLED出力パターン 1:明 0:暗 * Led_out_count (R/W) :エラーコードパターンの現在の出力位置 * Led_out_period (R) :現在の出力区間  * bit3:区間の有無(0:無区間 1:区間有、無区間とはプログラム起動直後を意味する)、 * bit0:区間(0:900msec区間 1:100msec区間) * * 使用する内部変数 * 無し */ void diag_led_out() { if (Err_code == 0) { // 出力すべきエラーが無い時 Tflag = 0; return; } // 出力すべきエラーがある時 if (Led_out_period == 0x11) { // 次が100msec区間の時 digitalWrite(LED_BUILTIN,Err_code_pattern[Led_out_count][0] ^ 0x01); // LED出力が 0:明 1:暗 なので符号を反転する } else { // 次が900msec区間の時 digitalWrite(LED_BUILTIN,Err_code_pattern[Led_out_count][1] ^ 0x01); // LED出力が 0:明 1:暗 なので符号を反転する Led_out_count++; if ( Led_out_count == 4 ){ Led_out_count = 0; } } Tflag = 0; // Tflagクリア } /* diag_timing():ダイアグLED出力区間切換えタイミングイベント発生の判定と、次に出す区間を指定する * 処理概要 * エラー発生時のLED出力パターンの切換えタイミングを監視し、切換えタイミングになったら以下の処理をする。 * ・次のLED出力区間の設定(100msec OR 900msec) * ・次の切換えタイミングの時刻の設定 * ・LED出力区間切換えタイミングイベント発生のフラグの設定 * * 使用する外部変数 * Led_on_off_timing(R/W) :次にLEDの出力を切換える時刻 * Tflag(R/W) :LED出力区間切換えタイミングイベント発生フラグ(1:発生、0:途中)  * ここで 1 にセットし、diag_led_out()で 0 にする * Led_out_period(R/W):現在の出力区間 * bit3:区間の有無(0:無区間 1:区間有、無区間とはプログラム起動直後を意味する) * bit0:区間(1:100msec区間 0:900msec区間) * 使用する内部変数 * 無し * * 戻り値 * Tflag * 0:切替りタイミングではない * 1:切替りタイミング発生 */ byte diag_timing() { if (Led_out_period == 0x00) { // プログラム起動直後(現在の出力が無区間)の時 → スタート処理 Tflag = 0; Led_on_off_timing = millis() + LED_OFF_TIME; Led_out_period = 0x10; } else if (Led_out_period == 0x10) { // 現在が900msec区間の時 → 100msec区間への切替え判定と処理 if (Led_on_off_timing <= millis()) { Tflag = 1; Led_on_off_timing += LED_ON_TIME; Led_out_period = 0x11; // 100msec区間 } } else if (Led_out_period == 0x11) { // 現在が100msec区間の時 → 900msec区間への切替え判定と処理 if (Led_on_off_timing <= millis()) { Tflag = 1; Led_on_off_timing += LED_OFF_TIME; Led_out_period = 0x10; // 900msec区間 } } return(Tflag); } /* i2c_write_byte():LIS3DHのレジスタへ1Byteデータを書き込む * 処理概要 * 以下の処理を順次実行する * ・Wire.beginTransmission(i2c_adrs):指定アドレスのスレーブへI2C通信の送信処理開始 * ・Wire.write(s_reg):レジスタアドレスのキューイング * ・Wire.write(s_data):レジスタへセットするデータのキューイング * ・Wire.endTransmission(false):キューイングしたデータの送信を実行する * 各処理でエラーが発生したら、そのコードを戻り値として返す * * 使用する引数 * i2c_adrs(R) :LIS3DHのスレーブアドレス * s_reg(R) :出力先のレジスタアドレス  * s_data(R) :出力データ * * 使用する内部変数 * i2c_ret(R/W) :Wireライブラリ実行時の戻り値 * err_code(R/W) :エラー発生個所を示すコード、この値がこの関数の戻り値となる * * 戻り値 * err_code * 0 :エラー無し * 0以外 :エラー有り */ byte i2c_write_byte(uint8_t i2c_adrs, uint8_t s_reg, uint8_t s_data){ byte i2c_ret, err_code; err_code = 0; Wire.beginTransmission(i2c_adrs); // 指定アドレスのスレーブへI2C通信の送信処理開始。 while( Wire.write(s_reg) != 1 ){ // レジスタアドレスのキューイング。 err_code = err_code | 0x01; } while( Wire.write(s_data) != 1 ){ // レジスタへセットするデータのキューイング。 err_code = err_code | 0x02; } i2c_ret = 5; while ( i2c_ret != 0){ i2c_ret = Wire.endTransmission(false); // キューイングしたデータの送信を実行する。 if(i2c_ret != 0 ){ err_code = err_code | 0x04; } } return (err_code); } /* i2c_read_bytes():LIS3DHのレジスタから指定Byteのデータを読み込む * 処理概要 * 以下の処理を順次実行する * ・Wire.beginTransmission(i2c_adrs):指定アドレスのスレーブへI2C通信の送信処理開始 * ・Wire.write(s_reg):レジスタアドレスのキューイング * ・Wire.endTransmission(false):キューイングしたデータの送信を実行する * ・Wire.requestFrom(i2c_adrs, s_bytes, false):i2cのレジスタからs_bytesバイトのデータ読み出しを宣言 * ・Wire.read():指定Byte数のデータを読み込む * 各処理でエラーが発生したら、そのコードを戻り値として返す * * 使用する引数 * i2c_adrs(R) :LIS3DHのスレーブアドレス * s_reg(R) :読込み先のレジスタアドレス  * *s_data(W) :読み込んだデータを格納する配列へのポインタ * s_bytes(R) :読み込むデータのByte数 * * 使用する主な内部変数 * i2c_ret(R/W) :Wireライブラリ実行時の戻り値 * err_code(R/W) :エラー発生個所を示すコード、この値がこの関数の戻り値となる * * 戻り値 * err_code * 0 :エラー無し * 0以外 :エラー有り */ byte i2c_read_bytes(uint8_t i2c_adrs, uint8_t s_reg, uint8_t *s_data, uint8_t s_bytes){ byte i2c_ret, err_code; int i, j; err_code = 0; Wire.beginTransmission(i2c_adrs); // I2C通信の開始。スレーブ側のアドレスの定義する。 while( Wire.write(s_reg) != 1 ){ // レジスタアドレスのキューイング。 err_code = err_code | 0x01; } i2c_ret = 5; while ( i2c_ret != 0){ i2c_ret = Wire.endTransmission(false); // キューイングしたデータの送信を実行する。 if(i2c_ret != 0 ){ err_code = err_code | 0x02; } } i2c_ret = 0xff; while ( i2c_ret != s_bytes){ i2c_ret = Wire.requestFrom(i2c_adrs, s_bytes, false); // 加速度センサからs_bytesバイトのデータ読み出しを宣言。 if(i2c_ret != s_bytes ){ err_code = err_code | 0x04; } } i = 0; j = s_bytes; while ( j != 0 ){ while (Wire.available() != j ); s_data[i] = Wire.read(); // 加速度データの読み出し。 i++; j--; } return (err_code); } /* read_fifo_status():LIS3DHのFIFO STATUSレジスタのデータを読み込む * 処理概要 * 以下の処理を順次実行する * ・i2c_read_bytes()関数を実行する * ・エラーが発生したらErr_codeに1をセット * fifo_statusを戻り値として返す * * 使用する引数 * 無し * * 使用する外部変数 * Err_code (R) :ダイアグ用エラーコード格納メモリ確保と0で初期化、0:エラー無し、1:エラー有り * * 使用する内部変数 * i2c_ret(R/W) :Wireライブラリ実行時の戻り値 * fifo_status(R/W) :FIFO STATUSレジスタから読み込んだ値、この値がこの関数の戻り値となる * * 戻り値 * fifo_status */ byte read_fifo_status() { byte fifo_status, i2c_ret; // FIFO_SRC_REGを読み込む i2c_ret = i2c_read_bytes(ACL_SEN_ADRS, FIFO_SRC_REG, &fifo_status, 1); if (i2c_ret != 0) { Err_code = 1; // エラー発生 } return ( fifo_status ); } /* read_fifo_data():LIS3DHのFIFOからデータを読み込む * 処理概要 * 以下の処理を順次実行する * ・i2c_read_bytes()関数を実行する * ・エラーが発生したらErr_codeに1をセット * fifo_statusを戻り値として返す * * 使用する引数 * *fifo_data(W) :FIFOからの読込みデータを格納するメモリへのポインタ * s_bytes(R) :FIFOから読み込むデータ数(Byte数) * * 使用する外部変数 * Err_code (R) :ダイアグ用エラーコード格納メモリ確保と0で初期化、0:エラー無し、1:エラー有り * * 使用する内部変数 * i2c_ret(R/W) :Wireライブラリ実行時の戻り値 * * 戻り値 * 無し */ void read_fifo_data(uint8_t *fifo_data, uint8_t s_bytes) { byte i2c_ret; // FIFOから指定バイト数のデータを読み込む i2c_ret = i2c_read_bytes(ACL_SEN_ADRS, OUT_X_L_B_READ, fifo_data, s_bytes); if (i2c_ret != 0) { Err_code = 1; // エラー発生 } } /* setup():セットアップ処理 * 処理概要 * ビルトインLEDのアクティブ化と首都力を「暗」にセット → 加速度計測処理の開始タイミング表示とエラー発生時のダイアグ表示用 * PSI通信のCSピン位置の設定 * I2C通信のバスマスターとして初期化 * LIS3DHの各種レジスタの設定 * サンプリング周波数:50Hz、XYZ軸全て有効、測定レンジ:±2G、高解像度有効 * SD card 有りの判定とライブラリの初期化 * NEXT_SDA.TXTから測定データを保存するファイル名を読み込み、次に保存するファイル名も一番下の行に追加する * (ファイル名はACL_Z000.XXXからACL_Z999.XXXまで1ずつ増加し、ACL_Z999.XXXの次はACL_Z000.XXXに戻る) * 計測開始時刻確認のタイミングの為のLED点滅(10秒「明」の後、3回の「暗・明」) * 加速度センサデータ読み込みとファイルへの保存関数(acl_z_record())の起動 * * 使用する外部変数 * Psi_chipSelect (R) :PSI(SDリーダーライターとのインターフェース)のチップセレクト * Err_code (R) :ダイアグ用エラーコード格納メモリ確保と0で初期化、0:エラー無し、1:エラー有り * F_name[] (R/W) :計測値データを保存するファイル名 * F_no[] (R/W) :計測値データ保存ファイル名の番号部の生成用バッファ * Next_f_name[] (R/W):次に作成する計測値データ保存ファイル名 * * 使用する主な内部変数 * i2c_ret (R/W) :Wireライブラリ実行時の戻り値 * tmp_char (R/W) :NEXT_SDA.TXTからデータを読み出す為のバッファ * * 戻り値 * 無し */ void setup() { byte i2c_ret; // Wireライブラリ実行時の戻り値 char tmp_char; // NEXT_SDA.TXTからデータを読み出す為のバッファ int i, f_no_i; pinMode(LED_BUILTIN, OUTPUT); // ビルトインLEDをアクティブ化 digitalWrite(LED_BUILTIN,1); // ビルトインLEDを「暗」にセット pinMode(Psi_chipSelect,OUTPUT); // PSI通信のCSピンの位置を設定する Wire.begin(); // I2C通信のバスマスターとして初期化 // LIS3DHのCTRL_REG1の設定。(50Hz、XYZ軸全て有効) i2c_ret = i2c_write_byte(ACL_SEN_ADRS, CTRL_REG1, CTRL_REG1_VALUE); if (i2c_ret != 0) { Err_code = 1; // エラー発生 } // LIS3DHのCTRL_REG4の設定。(+/- 2g、高解像度有効) i2c_ret = i2c_write_byte(ACL_SEN_ADRS, CTRL_REG4, CTRL_REG4_VALUE); if (i2c_ret != 0) { Err_code = 1; // エラー発生 } /* ----- SD card 有りの判定とライブラリの初期化------ */ #if DEBUG == 1 Serial.print("Initializing SD card..."); #endif //see if the card is present and can be initialized: if (!SD.begin(Psi_chipSelect)) { Err_code = 1; // エラー発生 } // NEXT_SDA.TXTから測定データを保存するファイル名(一番下の行)を読み込む File dataFile = SD.open("NEXT_SDA.TXT", FILE_WRITE); if( dataFile != false ) { while(!dataFile.seek(0)); // WRITEモードでオープンしているので、読み込み位置を先頭に戻す i = 0; while( true ){ // NEXT_SDA.TXT内の最後の行を読み込む(そこまでの行は読み飛ばす) tmp_char = char(dataFile.read()); if( tmp_char == 0xFF ){ break; } else { F_name[i] = tmp_char; i++; if( i == 14 ){ i = 0; } } } } else { Err_code = 1; // エラー発生 } if( F_name[12] != 0x0d ){ // NEXT_SDA.TXTの最終行にリターンが入っていない時は追加する dataFile.println(); } // 次回の計測の時に使うデータファイルの名前を生成して一番下の行へ追加する F_no[0] = F_name[5]; F_no[1] = F_name[6]; F_no[2] = F_name[7]; f_no_i = atoi(F_no); // 文字列を数字に変換 f_no_i++; if( 1000 <= f_no_i ){ f_no_i = 0; } sprintf(F_no, "%03d", f_no_i); // 数字を3桁の文字列に変換('0'で埋める) Next_f_name[5] = F_no[0]; Next_f_name[6] = F_no[1]; Next_f_name[7] = F_no[2]; dataFile.println(Next_f_name); dataFile.close(); // 計測開始時刻確認のタイミングの為のLED点滅(10秒「明」の後、3回の「暗・明」) digitalWrite(LED_BUILTIN,0); // ビルトインLEDを「明」にセット delay(10000); // 10秒間「明」 for( i = 0; i < 3; i++ ){ digitalWrite(LED_BUILTIN,1); // ビルトインLEDを「暗」にセット delay(300); // 0.3秒間「暗」 digitalWrite(LED_BUILTIN,0); // ビルトインLEDを「明」にセット delay(700); // 0.7秒間「明」 } digitalWrite(LED_BUILTIN,1); // ビルトインLEDを「暗」にセット acl_z_record(); } /* acl_z_record():Z軸加速度データの読込みとSDカードへの保存 * 処理概要 * 加速度データ用ファイル名(.SDA)、時間計測用ファイル名(.TXT)を作成する * LIS3DHの各種レジスタの設定 * FIFOの有効化、BYPASS modeの起動(FIFOバッファのクリアする)、FIFO modeの起動 * FIFOからZ軸加速度データを読み込み、Z軸加速度データバッファ(Acl_z[])へ格納する * Acl_z[]がACC_Z_BUF_SIZEで指定するサイズを超えたら.SDAファイルへバイナリモードで追記する * (書込みサイズは書込み平均時間が小さい256Bytesを狙う。約2.56秒毎に書込み) * Acl_z[]がACC_Z_BUF_SIZEで指定するサイズを超えていない間は、ダイアグ表示処理を行う * .SDAファイルへの書込みの32回毎に累積サンプル数とプログラム起動時からの経過時間を.TXTファイルへ保存する * .TXTファイルに保存している間のみLEDを「明」とする * (約82秒毎に光ります。これにより、プログラムが正常動作している事を確認できます。) * * 使用する外部変数 * F_name[] (R/W) :計測値データを保存するファイル名 * Err_code (R) :ダイアグ用エラーコード格納メモリ確保と0で初期化、0:エラー無し、1:エラー有り * Acl_z[] (R/W) :Z軸加速度データ読込み用バッファ * * 使用する主な内部変数 * f_name_sda[] (R/W) :Z軸加速度データを保存するファイル名 * f_name_time (R/W) :累積サンプル数とその時の時刻のデータを保存するファイル名 * fifo_count (R/W) :FIFOに溜まっているデータ数 * fifo_status (R/W) :FIFO_STATUS_REGの状態 * fifo_data[] (R/W) :FIFOから読み出したデータのバッファ * i2c_retWire (R/W) :Wireライブラリ実行時の戻り値 * tmp_acl_z (R/W) :Z軸加速度データ読込み用一時バッファ * s_read_count (R/W) :Acl_z[]バッファへ読込んだデータ数 * s_read_count_sum (R/W):Acl_z[]バッファへ読込んだデータ数の累積値 * timer_millis (R/W) :プログラム起動時からの経過時間(msec)の一時保管 * f_write_count (R/W) :.SDAファイルへの書込み回数 * * 戻り値 * 無し */ void acl_z_record() { char f_name_sda[] = "XXXXXXXX.SDA"; // .SDAファイルのオープン用のファイル名を保持する char f_name_time[] = "XXXXXXXX.TXT"; // 時間計測用データファイル(.TXT)のオープン用のファイル名を保持する //boolean tflag = false; byte fifo_count, fifo_status, fifo_data[192], i2c_ret; int16_t i, j, s_read_count, tmp_acl_z; uint32_t s_read_count_sum = 0, timer_millis, f_write_count = 0; /* F_nameの内、8bytes(.拡張子より前の部分)をf_name_sda、f_name_timeへコピー。*/ for ( i=0; i < 8; i++ ){ f_name_sda[i] = F_name[i]; // 加速度データ用ファイル名(.SDA) f_name_time[i] = F_name[i]; // 時間計測用ファイル名(.TXT) } // LIS3DHのCTRL_REG5の設定。(FIFOの有効化) i2c_ret = i2c_write_byte(ACL_SEN_ADRS, CTRL_REG5, CTRL_REG5_FIFO_EN); if (i2c_ret != 0) { Err_code = 1; // エラー発生 } // LIS3DHのFIFO_CTRL_REGの設定。(BYPASS modeの起動、これによりFIFOバッファをクリアする) i2c_ret = i2c_write_byte(ACL_SEN_ADRS, FIFO_CTRL_REG, FIFO_CTRL_REG_BYPASS); if (i2c_ret != 0) { Err_code = 1; // エラー発生 } // LIS3DHのFIFO_CTRL_REGの設定。(FIFO modeの起動) i2c_ret = i2c_write_byte(ACL_SEN_ADRS, FIFO_CTRL_REG, FIFO_CTRL_REG_FIFO); if (i2c_ret != 0) { Err_code = 1; // エラー発生 } timer_millis = millis(); // ACC_Z_BUF_SIZE(=128Bytes)サンプルの加速度データの読み込み(ここから)******************** while( true ) { i = 0; while( i < ACC_Z_BUF_SIZE ) {// ACC_Z_BUF_SIZE(=128Bytes)サンプルの加速度データの読み込み fifo_status = read_fifo_status(); // FIFO_SRC_REGの読込み fifo_count = fifo_status & FSS; if( (fifo_status & WTM) != 0 ) { // FIFO_WTMビットが「1」の時(WTMオーバーフローあり) read_fifo_data(fifo_data, (uint8_t)(fifo_count * 6)); for(j = 0; j < fifo_count; j++ ){ tmp_acl_z = (int16_t)fifo_data[j*6+5]; // Z軸加速度の上位バイトの読込み(1Byte変数→2Byte変数へ変換) tmp_acl_z = tmp_acl_z << 8; // 8Bit左シフト tmp_acl_z = tmp_acl_z | (int16_t)fifo_data[j*6+4]; // Z軸加速度の下位バイトを読込み、上位バイトと結合 Acl_z[i] = tmp_acl_z >> 4; // 4Bit右シフト(センサー値を12Bitで扱います) i++; } } else { // FIFOからの読込みタイミング以外の時 if (diag_timing() == 1) { // ダイアグ出力切換えタイミングの場合 diag_led_out(); // ダイアグ出力処理 } } } // ******************** ACC_Z_BUF_SIZE(=128Bytes)サンプルの加速度データの読み込み(ここまで) s_read_count = i; // s_read_count分の加速度データをSDカードに書き込む File dataFile = SD.open(f_name_sda, FILE_WRITE); // .SDAファイルをオープン if( dataFile == false ){ Err_code = 1; } dataFile.write((byte*)Acl_z, (s_read_count * 2)); dataFile.close(); // (s_read_count x 32)回分のセンサ読込み時の時刻データをSDに書き込む(ここから)********** // (サンプリング周波数が50Hzならば設計上は約81.92秒毎) f_write_count++; s_read_count_sum += s_read_count; if( f_write_count == 1 ){ // 計測開始時の時刻を書き込む dataFile = SD.open(f_name_time, FILE_WRITE); // .TXTファイルをオープン if( dataFile == false ){ Err_code = 1; } dataFile.print("S_No = 0"); dataFile.print(" Timer(msec) = "); dataFile.println(timer_millis); dataFile.close(); } if( f_write_count % 32 == 0 ){ // (s_read_count x 32)回毎に時刻を書き込む timer_millis = millis(); digitalWrite(LED_BUILTIN,0); // 書込み開始タイミングLEDを「明」にセット dataFile = SD.open(f_name_time, FILE_WRITE); if( dataFile == false ){ Err_code = 1; } dataFile.print("S_No = "); dataFile.print(s_read_count_sum); dataFile.print(" Timer(msec) = "); dataFile.println(timer_millis); dataFile.close(); digitalWrite(LED_BUILTIN,1); // 書込み開始タイミングLEDを「暗」にセット // ********** (s_read_count x 32)回分のセンサ読込み時の時刻データをSDに書き込む(ここまで) // SDカードへのアクセス可否確認(ここから)********** //(FILE_WRITEでSD.open実行時にエラーが検出できない場合があるため、ここで確認を追加する) dataFile = SD.open(f_name_time, FILE_READ); if( dataFile == false ){ Err_code = 1; } else{ if( char(dataFile.read()) == 0xFF ){ Err_code = 1; } } dataFile.close(); // ********** SDカードへのアクセス可否確認(ここまで) } } } void loop() { // 空ループ } 以上が「咳カウンター」システムの「機能紹介と記録モジュール編」です。 次回の「PC側ソフトウェア編」では加速度データをグラフ化するツールと、 咳を認識して時間帯毎のヒストグラムを作成するツールをPythonにて作成した 記事を紹介します。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

[Python]等差数列×等比数列の解を求めてみた[備忘録]

はじめに Pythonでコーディングを行ったので、備忘録として残します。 Python環境下にて、是非お使いください。 コーディング #sympy内の関数を使用 from sympy import * n=Symbol("n") a,b,c = map(int,input("a,b,c=? ? ?").split()) while c==0: print("Input a,b,c except c==0") a,b,c = map(int,input("a,b,c=? ? ?").split()) Prog = sequence((a*n+b)*c**n, (n, 0, 20)) #以下、出力(例:a,b,c=1 1 2) print([int(N) for N in Prog]) print(Prog[3]) 出力結果 [1, 4, 12, 32, 80, 192, 448, 1024, 2304, 5120, 11264, 24576, 53248, 114688, 245760, 524288, 1114112, 2359296, 4980736, 10485760, 22020096] 32 使い方 ▶a,b,cは半角スペースで区切り、入力してください。 ▶外部ライブラリにつき、予めインストールを要します。(以下、プロンプトでの入力例) conda install sympy #エラーが起これば、下のコマンドを実行してください。 pip install sympy 留意点 ▶c=0の場合、while文以下のループが発生します。(c≠0でお願いいたします。)
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

【ImageJからPythonへ】napariの使い方 (1)

はじめに ImageJ/Fijiにはしばらくお世話になっていたのですが、scipyやscikit-image、機械学習関連などといったPythonの豊富なライブラリが利用できず、何か実装したかったらJavaで書くしかありません。これは、実験者と開発者の分断を助長する、あまりにもよろしくない事態です。加えて、ImageJとExcelを行き来せざるを得なかったり、ウィンドウが増えすぎたりと、使っていてかなり不便だと感じる場面が少なくないと思います。 この現状を打破するべく調査していて行きついたのがnapariという素晴らしいPythonのライブラリでした。これがあれば、すべての解析をPythonで統一できます。ぜひ知ってもらいたいので、複数回に分けてまじめに紹介していこうと思います。 目次 napariとは napariでできること インストール 基本操作 スクリプトからのレイヤーの追加 スクリプトからレイヤーへのアクセス napariとは Pythonで画像を表示する際、Matplotlibでもいいのですが、拡大縮小したり、スライダーを動かして動画として見たり、立体的にぐるぐるしたりしたいですよね。napariは napari is a fast, interactive, multi-dimensional image viewer for Python と謳っている通り、快適に多次元画像を可視化することができます。詳しくは GitHubのリポジトリ 公式ドキュメント をお読みください。ここでは重要なところに絞って使い方を説明していきます。 ※ Julia言語からでも使えます。GitHubのこちらに上がっています。JuliaImagesを使い慣れているJulianの方は参考までに。 ※ scikit-imageもviewerというサブモジュールを有していますが、バージョン0.20で消されるので他のを使ってねと言っており、その中でnapariが紹介されています(こちら)。 この記事はnapariのバージョン0.4.10現在に関するものです。ホットに開発が進んでいるので、これから新機能の追加や変更があるかと思います。 napariでできること ImageJとの比較みたいになりますが、主に次のような特徴が挙げられます。 1. 1つのウィンドウ内で、画像やラベル、長方形などのオブジェクトをレイヤーとして複数重ねることができる。 これでウィンドウが増えすぎることが防げるうえ、複数の画像を重ねながら比較する際に毎回"Merge Channel"とかやらずに済むので非常に便利です。 2. キーボードショートカット、マウスのクリック時・ドラッグ時の機能を非常に簡単に登録でき、レイヤーごとに設定することもできる。 例えば、画像の特定領域を切り出すのはGUI頼みなので、そこからscipy.optimizeでフィッティングにスムーズに繋げることが困難でしたが、"F"ボタンにフィッティングを行う関数を登録しておけば解決します。ソフトを使い分ける時代はもう終わりで、これからは必要に応じて素早くUIを作り変える時代です。 3. 手動で選択したレイヤーやオブジェクトの情報をコンソールから対話的に取得できる。 2.とも関連しますが、Python上での画像解析の弱点は、人間が目で見て判断する操作は自動化できないという点です。細胞の領域を取り出してくる、結晶の周期構造がx軸に平衡になるように回転させる、といった操作くらいは人間がやることだと思います。これを毎回matplotlibでプロットしてはおおよその値を入力するのでは効率が悪いですが、napariを通して行えばずっとスムーズになります。 4. なんならImageJを呼び出せる。 pyimagejというものをインストールするとPythonからImageJを起動し、データをやり取りできます(需要はあるのか?)。環境構築が一筋縄では行かなかったのですぐにやめました。 インストール 画像解析をするならcondaを使いましょう。 conda install napari -c conda-forge PyQtなどが最新のもの (>=5.12) になっていないと思うように動作しないかもしれないので、conda-forgeチャネルの最新のものにしておきましょう。 conda install pyqt -c conda-forge 基本操作 Viewerを立ち上げて画像を表示 モジュールをインポートします。 import napari import numpy as np import skimage napari.Viewerクラスが、napariのウィンドウ本体になります。ウィンドウ内にあるものはすべてここからアクセスします。次のようにインスタンスviewerを作成します。 viewer = napari.Viewer() napari.run() # Jupyterであれば、これは不要 インスタンスが作成された時点で、図のようにnapariのウィンドウが開きます。 (初期設定によって多少レイアウトが変わる可能性あり) 続いて画像を追加します。napari.Viewerのメソッドadd_imageを使います。 img = skimage.data.camera() # 画像を取得 viewer.add_image(img) [Out] <Image layer 'img' at 0x...> これでウィンドウに画像が表示されました!ドラッグで平行移動、スクロールで拡大縮小できます。 ちなみに、ウィンドウを立ち上げるだけのことはほとんどないので、同時に画像を送る関数view_imageも用意されています。 viewer = napari.view_image(img) # 新しいウィンドウを立ち上げてそこに画像を送る ウィンドウ上での簡単な操作 napariはレイヤーを追加していく形をとります。次のGIFのように、"layer list"ウィジェット内にあるボタンから順にPointsレイヤー、Shapesレイヤー、Labelsレイヤーを追加できます。 各レイヤーは選択中に、左上の"layer control"ウィジェットから固有の操作ができます。例えば、 Points ... matplotlibの操作を手動で行うイメージ 新しい点の追加/削除(3次元的にもできる!) 点の移動 点の形の変更 Shapes ... PowerPointのイメージ 長方形、折れ線、楕円などのオブジェクトを追加/削除 オブジェクトの選択、拡大/縮小、回転 オブジェクト頂点の追加/削除 Labels ... ペイントのイメージ 色を塗ったり消したりする(異なる色は異なる自然数のラベルに対応する) 塗りつぶし カラーピック 3次元描画(!) などができます。ショートカットの整備が不十分なのと、一部Ctrl+Zが利かないのがこれからの課題ですが、欲しい機能はだいたい揃っている感じがしますね。 他にも、上部のメニューバーから座標軸やスケールバーを表示したり、スクリーンショットと撮ったりできます。 画像の3D表示 napariが得意とする3D表示をやってみます。今回はメニューバーFile > Open Sample > scikit-image > Binary Blobs (3D)とすることでサンプル画像を取得します。scikit-imageのサンプル画像を全部ダウンロードしている場合は"Brain (3D)"や"Cells (3D+2Ch)"など、もっとマシな画像が使えるかと思います。 デフォルトでは2D表示なので、下部のスライダーを動かして断面をみる形になります。3D関係の操作は"layer list"の下のボタンから行います。 3Dに切り替えると、"Binary Blobs (3D)"だとほぼ真っ白になってしまいます(バイナリ画像なので)。これは、レンダリングが"mip (max intensity projection)"になっているためで、バイナリ画像の場合"iso (iso-surface)"に設定すると見やすくなります。ぐるぐる回しましょう。 プラグインの導入 napariは機能追加のしやすさが売りなので、当然プラグインシステムが既に確立しています。上部のメニューバーのPlugins >> Install/Uninstall Package(s)...を選択すると、次の図に示すように新しいウィンドウに現在手に入るプラグインや更新可能なプラグインなどが表示され、そのまま手動でpipインストールできます。 簡単にアニメーションを作れるnapari-animationは特におすすめです。 スクリプトからのレイヤーの追加 続いては、スクリプトからレイヤーを追加していく方法について詳しく説明します。すでに述べたようにadd_imageを用いることでとりあえず画像をnapariに送ることができますが、実際には異なる色でマージした状態で送ったり、画像ではなく点や長方形を追加したりしたいこともあります。napariには多彩なレイヤーが用意されているので、状況に応じて使い分けられるとよいですね。 より詳細な情報はこちらを参照。 → napariの公式ドキュメント 0. 各レイヤー共通のパラメータ レイヤーはすべてLayerクラスを継承しており、レイヤーの種類にかかわらず、共通したプロパティがあります。これらはキーワード引数も共通しているので、初めにまとめてしまいます。 詳細を表示 name=None, metadata=None, scale=None, translate=None, rotate=None, shear=None, affine=None, opacity=1, blending='translucent', visible=True, multiscale=False name str ... レイヤーの名前。他と被る場合は"XXX [1]"のように後ろに番号がつく。 metadata dict ... レイヤーのメタデータ。補足・メモ用。 scale tuple of float ... 各軸のスケール。例えばscale=(0.5, 0.2)なら、y軸で50%、x軸で20%のスケールがかかり、縦長のピクセルになる(行/列はy/xの順に対応することに注意!)。 translate tuple of float ... 各軸の平行移動。例えばtranslate=(3,5)なら、下に3ピクセル、右に5ピクセル移動して画像が表示される。 rotate float, tuple of float ... 反時計回り、$(0, 0)$のピクセルを中心とした回転。ラジアンではなく度数。tuple of floatで与えられた場合は3次元回転に対応する。 shear 1-D array ... せん断(正方形をひし形につぶすような変換)。2D画像の場合、shear=[a]とするとx軸が$y=ax$方向に向いたように画像がつぶれる。 affine 2-D array ... アフィン変換行列。scale, translate, rotate, shearはアフィン変換行列を分かりやすく分解したものなので、それらを指定するのとaffineを指定するのは数学的には同義となる。 opacity float ... 非透明度、もしくはアルファチャネル。0で完全に透明、1で不透明になる。 blending str ... 他のレイヤーとどのようにブレンドするか。opaque (不透明)、translucent (半透明)、additive (加算的) のいずれか。要するに複数のレイヤーが重なるときに、奥側のレイヤーを隠すかどうかを指定する。しかしなぜかtranslucentで半透明にならない。 visible bool ... レイヤーを見える状態で追加するか否か。 multiscale bool ... マルチスケールで画像を表示するか否か。非常に大きな画像では拡大縮小の倍率変化が大きくなる。このとき、細部を見るためにズームインしているときは1ピクセル単位で表示する必要があるが、全体を見るためにズームアウトしているときは画像の表示のために全ピクセルを計算させる必要はない。このとき、add_image([img, img[::4, ::4]], multiscale=True)のように与えれば、ズームアウトしていくと途中で勝手に4ピクセルごとの表示に切り替わる。 なお、レイヤーの有効な利用例がGitHubリポジトリのこちらにまとまっています。 1. Image layer viewer.add_image(...)で配列を画像として追加します。 詳細を表示 add_image(data=None, *, channel_axis=None, rgb=None, colormap=None, contrast_limits=None, gamma=1, interpolation='nearest', rendering='mip', iso_threshold=0.5, attenuation=0.05, name=None, metadata=None, scale=None, translate=None, rotate=None, shear=None, affine=None, opacity=1, blending=None, visible=True, multiscale=None) data array ... 画像データ。普通ならnumpy.ndarrayだが、メモリに乗らない大きな画像をdaskで渡してもちゃんと表示される。daskの使い方に関しては以下を参照。 daskの公式ドキュメント 画像をdaskで読み込んでnapariで表示する方法 channel_axis int ... チャネル軸。この軸に沿って画像が分割され、異なる色、異なるレイヤーで画像が追加される。例えばTIFFファイルではchannel, y, xの順に軸が並んでいるので、channel_axis=0と指定することでチャネルごとに色付けされる。blendingも指定しなければ"additive"に切り替わる。 rgb bool ... RGB画像として追加するか。この場合は軸はPNG画像などの順番であるy, x, colorに従う必要がある。 colormap str ... カラーマップ、もしくはLUT (look up table)。画像のピクセル値と色を対応させるもので、napariが内部で使っているvispyで用意されているものを使うのが手っ取り早い。"gray", "plasma"など、matplotlibのカラーマップの代表的なものが用意されていると考えるとよい。 contrast_limits list ... 画像のコントラストの最小/最大。蛍光顕微鏡画像などでは外れ値があって周りが暗くなるので、np.percentileなどで外れ値を一部サチュレートさせたコントラストを使うことになる。 gamma float ... ガンマ値。ピクセル値とカラーマップを非線形に対応させるときに用いる。 interpolation str ... ピクセル間の補間の方法。画像をズームインし、スクリーンの画素よりも画像の1ピクセルが大きくなった時にどのように画像を表示するかを決める。例えば"nearest"はピクセルの正方形がそのまま拡大され、"bilinear"や"bicubic"ではなめらかに接続される。 rendering str ... 3D表示のときのレンダリング方法。 iso_threshold float ... rendering="iso"のときのパラメータ。 attenuation float ... rendering="attenuated_mip"のときのパラメータ。 2. Labels layer viewer.add_labels(...)で整数配列をラベルとして追加します。異なる整数は使える色を使いきるまで異なる色で表示され、0は背景となり透明になります。 詳細を表示 add_labels(data, *, num_colors=50, properties=None, color=None, seed=0.5, name=None, metadata=None, scale=None, translate=None, rotate=None, shear=None, affine=None, opacity=0.7, blending='translucent', visible=True, multiscale=None) data array ... ラベルデータ。scipy.ndimage.labelなどの出力をそのまま用いるのが基本となる。 SciPyの公式ドキュメント 例えば2D画像imgに対し from scipy import ndimage as ndi binary_input = img>100 structure = [[0,1,0],[1,1,1],[0,1,0]] labels, n_labels = ndi.label(binary_input, structure) で得られる配列labelsを渡すことで可視化できる。 num_colors int ... 用いる色の数。ラベル数に対して大きすぎる値を指定すると似た色が使われてしまうし、そもそも似た色を目で区別できるかという問題もあるので、基本的にデフォルトの50で問題ないと思われる。 properties dict or DataFrame ... (背景を含む)各ラベルに与えるプロパティ。ラベルの上にマウスを合わせると表示されるので、画像と対応させながら参照したい情報を載せるとよい。skimage.measure.regionprops_tableの結果と互換性があるが、こちらはデフォルトでは背景を含まないので注意 (→ scikit-imageの公式ドキュメント)。 color dict or array ... カスタムで色を指定する場合に用いる。 seed float ... ラベルの色をランダム生成する際のランダムシード。 3. Points layer viewer.add_points(...)で座標のリストをポイントとして追加します。 詳細を表示 add_points(data=None, *, ndim=None, properties=None, text=None, symbol='o', size=10, edge_width=1, edge_color='black', edge_color_cycle=None, edge_colormap='viridis', edge_contrast_limits=None, face_color='white', face_color_cycle=None, face_colormap='viridis', face_contrast_limits=None, n_dimensional=False, name=None, metadata=None, scale=None, translate=None, rotate=None, shear=None, affine=None, opacity=1, blending='translucent', visible=True, property_choices=None) data array ... 点の座標のリスト。D次元のN点を追加する場合、(N, D)という形状になる。 ndim int ... 空のレイヤーを作るときに、何次元にするかを指定する。 properties dict or DataFrame ... 各点のプロパティ。Labelsと同様だが、Pointsレイヤーでは後述するtextやface_colormapと組み合わせてることで、さらに高度な利用ができる。 text str or dict ... propertiesとの組み合わせで、format文字列の形式でポイントにテキストを付随させる。例えば、xy平面上でランダムに点を配置し、座標を$(x, y)$の書式のテキストで追加する場合は次のようになる。イメージとしては、各点に関してtext.format(**properties)のようなコードが走っていると考えるとよい。 arr = np.random.random((8,2))*100 # 8点 viewer.add_points(arr, text="({x:.1f}, {y:.1f})", properties={"x": arr[:,1], "y": arr[:,0]} ) symbol str ... 点の形。デフォルトの"o"は円に対応する。 size float ... 点の大きさ。 edge_width float ... 枠線の太さ。 edge_color/face_color str ... 枠線/塗りつぶしの色。 edge_color_cycle/face_color_cycle ndarray or list ... 枠線/塗りつぶしのカラーサイクル。複数の色を順番に使っていく場合に用いる。 edge_colormap/face_colormap str ... propertiesとの組み合わせで、各点に付随した値に応じて枠線/塗りつぶしの色を変えることができる。例えば点の座標に信頼度のようなものがあり、信頼度の高い点ほど赤色で明るくしたいとき、次のようなコードで実現できる。 arr = np.random.random((4, 2))*100 # 4点 properties = {"confidence": [0.9, 0.3, 0.6, 0.1]} # 各点の信頼度 viewer.add_points(arr, properties=properties, face_colormap="red", face_color="confidence" ) edge_contrast_limits/face_contrast_limits ... 枠線/塗りつぶしの色のコントラスト値。 n_dimensional bool ... Trueのとき、各点は立体的な大きさを持つようになる。例えば2Dで円だったものは、3Dでは球に見えるようになる。 property_choices dict ... propertiesで指定された各プロパティが取りうる値。 4. Shapes layer viewer.add_shapes(...)により長方形、折れ線、円などのオブジェクトを追加します。手動での編集、テキストの追加も容易なので、PowerPointのように画像のアノテーションに使ったり、ImageJのROIのように使ったりできます。 オプションはPointsレイヤーとほぼ同じです。キーワード引数の詳細はそちらを参照してください。 詳細を表示 add_shapes(data=None, *, ndim=None, properties=None, text=None, shape_type='rectangle', edge_width=1, edge_color='black', edge_color_cycle=None, edge_colormap='viridis', edge_contrast_limits=None, face_color='white', face_color_cycle=None, face_colormap='viridis', face_contrast_limits=None, z_index=0, name=None, metadata=None, scale=None, translate=None, rotate=None, shear=None, affine=None, opacity=0.7, blending='translucent', visible=True) data list or array ... 追加するオブジェクトの頂点の座標を指定する。listで複数同時に追加することができる。頂点の座標の数は次元数やオブジェクトの種類によって変わる。 shape_type str ... "line", "rectangle", "ellipse", "path", "polygon"のいずれか。 z_index int or list ... 前後の配置を決める。大きい値ほど前に出る。 例えば長方形を追加する場合 viewer.add_shapes([[2,2],[2,10],[8,10],[8,2]], shape_type="rectangle") のようになる。 5. Surface layer viewer.add_surface(...)により、3D表面を表示できます。これは3D物体の境界を定量的に強調したり、地図と標高から立体的な俯瞰図を作成したりするときに用います。 詳細を表示 add_surface(data, *, colormap='gray', contrast_limits=None, gamma=1, name=None, metadata=None, scale=None, translate=None, rotate=None, shear=None, affine=None, opacity=1, blending='translucent', shading='flat', visible=True) data tuple of arrays ... (頂点, 平面を形成する頂点の組, 値)のtupleで指定する (最後の値は任意)。現実的には3D画像から例えばskimage.measure.marching_cubeを用いて生成することになる。 scikit-image公式ドキュメント from skimage.meansure import marching_cubes verts, faces, _, values = marching_cubes(image) viewer.add_surface((verts, faces, values)) scikit-imageのサンプル画像である"brain"からSurfaceを生成するとこんな感じになります。 colormap, contrast_limits, gamma ... Imageと同様。 shading str ... "none", "flat", "smooth"のいずれか。影の付け方を指定する。 6. Track layer viewer.add_tracks(...)により、時間的に移動している点に残像を付けることで軌跡を表現できます。惑星の運動や、蛍光輝点の追跡結果の表示に非常に便利です。 詳細を表示 add_tracks(data, *, properties=None, graph=None, tail_width=2, tail_length=30, name=None, metadata=None, scale=None, translate=None, rotate=None, shear=None, affine=None, opacity=1, blending='additive', visible=True, colormap='turbo', color_by='track_id', colormaps_dict=None) data array 軌跡データ。D次元データの軌跡は(N, D+1)の形状をした配列で表される。2次元+時間の軌跡であれば、左の列から順にID、時間、y座標、x座標となる。IDが同じ点は一つの軌跡に属すると認識される。例えば時刻t=0-10で動く点とt=4-15で動く点が表現したければ、与えるべき配列は以下のような構成になる。 ID t y x 0 0 $y^0_0$ $x^0_0$ : : : : 0 10 $y^0_{10}$ $x^0_{10}$ 1 4 $y^1_4$ $x^1_4$ : : : : 1 15 $y^1_{15}$ $x^1_{15}$ properties dict or DataFrame ... 各点でのプロパティ。各軌跡ではないので、各プロパティの長さはdataの行数に一致する。時間経過に伴う軌跡の状態変化なども記述できるということ。 graph dict ... 軌跡の分離/融合をdictで{子:親}の形で表現する。例えばgraph = {1: 0, 2: 0, 3: [1, 2]}であれば、点0が点1,2に分離し、その後点1,2が点3に融合する軌跡が描ける。詳細はこちらのコードで確認できる。 tail_width, tail_length ... 軌跡の見せ方を変えるパラメータ。 colormap str ... 軌跡の色付けに用いるカラーマップ。 color_by str ... 何で色付けするか。デフォルトではIDごとに色付けされるが、propertiesを与えていれば好みのプロパティで色付けできる。 colormaps_dict dict ... 異なるプロパティで異なるカラーマップを使いたいときに指定する。他のカラーマップ関係の引数と異なり、なぜか辞書の値はnapari.utils.Colormapしか受け付けない。 7. Vector layer viewer.add_vectors(...)によりベクトル場を表示できます。add_imageと組み合わせてポテンシャル場・ベクトル場の時系列変化を可視化できるので、流体力学シミュレーションなどに向いているかと思います。 詳細を表示 add_vectors(data, *, properties=None, edge_width=1, edge_color='red', edge_color_cycle=None, edge_colormap='viridis', edge_contrast_limits=None, length=1, name=None, metadata=None, scale=None, translate=None, rotate=None, shear=None, affine=None, opacity=0.7, blending='translucent', visible=True) data array ベクトル場を表す配列。始点/終点を記述するパターン(scipy.ndimage.map_coordinatesに似た様式)と、各点のベクトル値を与える、ベクトル場らしい記述パターンがある。D次元空間でのdataの形状は 前者の場合 $(N, 2, D)$ であり、"N"はN本のベクトル、"2"は始点/終点に対応している。 後者の場合 $(N_1, ..., N_D, D)$ であり、例えばdata[0, 0, :]は2D平面上の点 $(0, 0)$ におけるベクトルの2成分に対応する。 properties dict or DataFrame ... 各ベクトルのプロパティ。これまで説明したレイヤーと同様。 edge_width/edge_color/edge_color_cycle/edge_colormap/edge_contrast_limits ... ベクトルの枠線の太さや色を決める。Pointsレイヤーと同様。 length float ... ベクトルの長さを何倍するか。ベクトルの値が極端に大きかったり小さかったりするときに指定する。 スクリプトからレイヤーへのアクセス napariのもう一つの強みは、手動で描いた点や図形の座標情報をnumpy.ndarrayとして受け取れる点です。これでPython上での画像解析の弱点を完全に克服できます。 レイヤーをスクリプトから指定して情報を得る napari.Viewerは多次元画像の可視化に必要ないくつもの変数からなっています。例えば axes (座標軸) camera (カメラの向きなど) layers (レイヤーリスト) scale_bar (スケールバー) などがあります。 ここではlayersからレイヤーの情報を得ることについて解説します。layersはLayerListクラスのオブジェクトで、Layerオブジェクトを格納したPythonのlistのようなものです。したがって、一番下のレイヤーが欲しければ layer = viewer.layers[0] でアクセスします。もしくは、"XXX"という名前のレイヤーであれば layer = viewer.layers["XXX"] で指定してもOKです。 ここで得られる変数layerはLayerオブジェクトなので、配列データ、名前、メタデータなどすべて含まれています。それぞれ次のようにアクセスします。 layer.data # 配列データ layer.name # 名前 layer.metadata # メタデータ 他にもレイヤーの種類によって様々な情報が得られますが、前章で登場したキーワード引数は、ほとんどすべてがそのままレイヤーオブジェクトの属性になっているので、難しくないと思います。 マウスで選択したレイヤーの情報を得る 毎回レイヤーの順番や個数は変わりうるので、できればマウスでクリックしてレイヤーを指定したいです。例えば以下のように、サンプル画像にフィルタやラベル付けをし、顔の部分を手動で囲った状況で、長方形の座標を得たいとしましょう。レイヤーリストを見ると"Shapes"が選択されています。 これはLayerListクラスのselectionプロパティで得ることができます。 viewer.layers.selection [Out] Selection({<Shapes layer 'Shapes' at 0x...>}) 見ての通り選択されているレイヤーがPythonのsetのようなものに格納されているので、次のようにアクセスする必要があります。 viewer.layers.selection.pop() # OK。pop()なので選択は解除される。 list(viewer.layers.selection)[0] # OK。もとのsetは変わっていないので選択は解除されない。 viewer.layers.selection[0] # Error! これで、コード自体は変えずに、状況に応じて手動で選択できるようになりました。 layer = viewer.layers.selection.pop() layer.data[0] # 0番目の図形の座標 [Out] array([[ 60.44640515, 159.68169239], [ 60.44640515, 274.63154028], [203.29670156, 274.63154028], [203.29670156, 159.68169239]]) 終わりに 読んでいただきありがとうございます。napariの多彩な機能は1回では紹介しきれないので、また次回に回します。予定としては よく使う自作関数をキーボードショートカットとして登録する マウスの動きを認識させてViewerでライブで解析する ウィジェットを追加する などを書いていこうかと思います。他にもご要望があればコメント欄にどうぞ!
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Atcoder ABC212 Python C問題振り返り

前書き C問題でACできなかったので今回も振り返っていきます。 Atcoder Factsという問題ごとの正解率をレート帯ごとに出してくれるサイトを見たところ、200-399のレートにいる人達の正解率が63.85%だったので解けるようにしておきたいですね... 自分の提出(TLE) ソートするところまでは良かったのですが、その後でどうすればよいか全く浮かびませんでした。 N, M = map(int,input().split()) A_2 = list(map(int,input().split())) B_2 = list(map(int,input().split())) A = sorted(A_2) B = sorted(B_2) ans = 0 if A[-1] <= B[0]: ans = B[0] - A[-1] print(ans) exit() elif B[-1] <= A[0]: ans = A[0] - B[-1] print(ans) exit() flg = [] Ans = [] for i in range(N): a = abs(A[i] - B[0]) b = abs(A[i] - B[-1]) if a > b: for j in range(-1,-M//2,-1): flg.append(abs(A[i] - B[j])) Ans.append(min(flg)) flg = [] elif a <= b: for j in range(0,M//2): flg.append(abs(A[i] - B[j])) Ans.append(min(flg)) flg = [] ans = min(Ans) print(ans) 修正後の提出(AC確認済み) Atcoderで解説動画を出してくださっているsnukeさんの解説を聞くことで実装できました。 snukeさんの解説は図を書いて視覚的に説明して下さるのでとても分かりやすいです。 (いつもありがとうございます...) import math N, M = map(int,input().split()) A_2 = list(map(int,input().split())) B_2 = list(map(int,input().split())) A = sorted(A_2) B = sorted(B_2) a = 0 b = 0 ans = math.inf while a < N and b < M: ans = min(ans, abs(A[a] - B[b])) if A[a] < B[b]: a += 1 else: b += 1 print(ans) 感想 AとBのリストをソートしてあげて、それぞれにカーソルを用意して(このコードでいうaとb)、一つずつ進めてあげるイメージらしいです。(片方のカーソルの位置を超えたタイミングで超えられた方のカーソルを進めていく) リストBの要素をソート後のリストAに入れて(ソートされている状態をくずさないで)、両隣との差を出していくのかなと考えていましたが、カーソルをイメージするやり方の方が個人的に分かりやすいなと感じました。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Databricks(Spark)にてDelta Lake形式・Parquet形式でAnalyze Tableを実行した際に格納されるデータの調査結果

概要 Databricks(Spark)にてDelta Lake形式・Parquet形式でAnalyze Tableを実行した際に格納されるデータの調査結果を共有します。Analyze Table実行後に、Hive Metastoreデータベース(MySQL)のTABLE_PARAMSテーブルとPARTITION_PARAMSテーブルにデータが格納され、SparkからDESC EXTENDEDにより確認できます。 統計情報の種類としては3種類あり、Delta Lake形式では、パーティションの統計情報をサポートしていないようです。ただ、Sparkにてパーティションの統計情報がどのように利用されるかを確認できませんでした。 テーブルの統計情報 カラムの統計情報 パーティションの統計情報 Sparkの統計情報に関しては下記の資料が参考になります。 引用元:Spark SQL Beyond Official Documentation 詳細は下記のGithub pagesのページをご確認ください。 コードを実行したい方は、下記のdbcファイルを取り込んでください。 https://github.com/manabian-/databricks_tecks_for_qiita/blob/main/tecks/survey_about_analyze_table/dbc/survey_about_analyze_table.dbc 実行環境 databricks runtime: 8.3.x-scala2.12 Python version: 3.8.8 Pyspark version: 3.1.2.dev0 データの確認 Parquet形式におけるテーブルの統計情報 DESC EXTENDEDの実行結果 DESC EXTENDED students_parquet; Hive Metastoreデータベースへのクエリ実行結果 SET @TABLE_NAME="students_parquet"; SELECT target.* FROM TABLE_PARAMS AS target INNER JOIN TBLS AS master on target.TBL_ID = master.TBL_ID and master.TBL_NAME = @TABLE_NAME ; Delta Lake形式におけるテーブルの統計情報 DESC EXTENDEDの実行結果 統計情報が確認できませんでした。 DESC EXTENDED students_delta; Hive Metastoreデータベースへのクエリ実行結果 SET @TABLE_NAME="students_delta"; SELECT target.* FROM TABLE_PARAMS AS target INNER JOIN TBLS AS master on target.TBL_ID = master.TBL_ID and master.TBL_NAME = @TABLE_NAME ; Parquet形式におけるカラムの統計情報 DESC EXTENDEDの実行結果 DESC EXTENDED students_delta student_id; Hive Metastoreデータベースへのクエリ実行結果 SET @TABLE_NAME="students_parquet"; SELECT target.* FROM TABLE_PARAMS AS target INNER JOIN TBLS AS master on target.TBL_ID = master.TBL_ID and master.TBL_NAME = @TABLE_NAME ; Delta Lake形式におけるカラムの統計情報 DESC EXTENDEDの実行結果 DESC EXTENDED students_parquet student_id; Hive Metastoreデータベースへのクエリ実行結果 SET @TABLE_NAME="students_delta"; SELECT target.* FROM TABLE_PARAMS AS target INNER JOIN TBLS AS master on target.TBL_ID = master.TBL_ID and master.TBL_NAME = @TABLE_NAME ; Parquet形式におけるパーティションの統計情報 DESC EXTENDEDの実行結果 DESC EXTENDED students_parquet PARTITION (student_id = 111111); Hive Metastoreデータベースへのクエリ実行結果 SET @TABLE_NAME="students_parquet"; SELECT target.* FROM PARTITION_PARAMS AS target INNER JOIN PARTITIONS AS master on target.PART_ID = master.PART_ID INNER JOIN TBLS AS master2 on master.TBL_ID = master2.TBL_ID and master2.TBL_NAME = @TABLE_NAME ; Delta Lake形式におけるパーティションの統計情報 DESC EXTENDEDの実行結果 データを確認できませんでした。 DESC EXTENDED students_delta PARTITION (student_id = 111111); Hive Metastoreデータベースへのクエリ実行結果 データを確認できませんでした。 SET @TABLE_NAME="students_delta"; --レコードを取得できない想定 SELECT target.* FROM PARTITION_PARAMS AS target INNER JOIN PARTITIONS AS master on target.PART_ID = master.PART_ID INNER JOIN TBLS AS master2 on master.TBL_ID = master2.TBL_ID and master2.TBL_NAME = @TABLE_NAME ;
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

サポートベクターマシンと畳み込みニューラルネットワーク(とディープラーニング)の比較

自己紹介、目的 機械学習をゼロから学んでいる者です。ここまでの学習のまとめとして、3つの手法を用いて画像分類を行いましたので、その結果と簡単な分析をブログに記載します。使用した3手法は以下です。 サポートベクターマシン ディープラーニングではない機械学習の代表格として。 ディープラーニング(CNN不使用) 結果としては全く上手くいかなかったのですが、比較のため記載します。 畳み込みニューラルネットワーク(CNN) ディープラーニングの画像分類における有力な手法です。 環境 MacBookAir  early 2020, RAM 8GB, CPU intel core i5 (非M1) OS: macOS Big Sur Ver. 11.4 docker上のJupyter NotebookのPython上にて行いました。ここで使用したdockerイメージには、tensorflow、scikit-learn、pandas、numpy、matplotlib、RandomUnderSamplerをインストールしています。この環境構築は以下の記事を参照してください。 結果概要 データセット Scikit-learnから入手できるLFW people 4名の画像を各120枚づつ用意。 70%をトレーニング、30%をテストに使用。 各手法における正解率 サポートベクターマシン:正解率85% ディープラーニング(CNN不使用):正解率25% *適当に分類しても1/4の確率で正解するので、意味ない・・・ 畳み込みニューラルネットワーク(CNN):正解率79% 実行内容詳細 Jupyter Notebookではセルごとにpythonを実行できましたので、それと同様に実行した順に記載します。 データセットのインポート from sklearn.datasets import fetch_lfw_people LFW peopleのデータセットのインポートです。初回は少し時間がかかります。 データ確認1 face_data = fetch_lfw_people(min_faces_per_person=120) X = face_data.data Y = face_data.target print("input data size:", X.shape) print("output data size:", Y.shape) print("label name:", face_data.target_names) 出力 input data size: (1031, 2914) output data size: (1031,) label name: ['Colin Powell' 'Donald Rumsfeld' 'George W Bush' 'Tony Blair'] 1行目でface_dataにデータを格納しています。LFW peopleのデータセットには5749名分、13,233枚の画像が格納されていますが、人物により画像数にばらつきがあります。min_faces_per_personを指定することで人物毎の最少画像数を指定でき、120を指定すると4名の画像のみとなります。 Xにデータを、Yにターゲットを代入しました。X.shapeでXのサイズを確認して、行数1031は1031枚の画像ということ、列数2914は画像毎のデータ数は2914データ(縦62ピクセル、横47ピクセル)ということがわかりました。 4名の人物の名前を確認したところ、Colin Powell、Donald Rumsfeld、George W Bush、Tony Blairとなりました。 データ確認2 import matplotlib.pyplot as plt fig, ax = plt.subplots(3, 4) plt.subplots_adjust(wspace=0.8, hspace=0.5) for i, axi in enumerate(ax.flat): axi.imshow(face_data.images[i], cmap='gray') axi.set(xticks=[], yticks=[], xlabel=face_data.target_names[face_data.target[i]]) plt.show() 出力 matplotlibを用いてどのような画像か確認しました。 サンプル数確認 for i in range(len(face_data.target_names)): print("{} has {} samples".format(face_data.target_names[i], (Y == i).sum())) 出力 Colin Powell has 236 samples Donald Rumsfeld has 121 samples George W Bush has 530 samples Tony Blair has 144 samples 各人物毎の画像数を確認しました。データ数に偏りがあるので、次のセルでこれを是正します。 アンダーサンプリングによるデータの偏り修正 from imblearn.under_sampling import RandomUnderSampler import pandas as pd df=pd.DataFrame(X) df['target']=Y strategy = {0:120, 1:120, 2:120, 3:120} rs=RandomUnderSampler(random_state=42, sampling_strategy = strategy) df_sampled,_=rs.fit_resample(df,df.target) print(df_sampled.target.value_counts()) 出力 3  120 2  120 1  120 0  120 Name: target, dtype: int64 RandomUnderSamplerを用いてデータの偏りを無くしました。 Donald Rumsfeldさんが121枚でしたので、キリのいい120枚で統一しました。 出力にて120枚ずつにアンダーサンプリングできていることが確認できます。 訓練データとテストデータへの分割 from sklearn.model_selection import train_test_split Y = df_sampled.target.values X = df_sampled.iloc[:,0:2914].values X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3,stratify=Y, random_state=0) print(X.shape) print(X_train.shape) print(y_train.shape) print(X_test.shape) print(y_test.shape) #念のため確認 df = pd.DataFrame(y_train) df.value_counts() 出力 (480, 2914) (336, 2914) (336,) (144, 2914) (144,) 3 84 2 84 1 84 0 84 dtype: int64 ダウンサンプリングしたデータフレームからXとYにデータを戻します。 そして大変便利なtrain_test_splitを用いて訓練データとテストデータに分割します。 データのshapeを確認そたところ、それぞれ所望のものとなっていることが確認できました。 さらに念のためy_trainのデータ数を確認したところ、一名につき84枚としっかりデータの偏りがないことを確認しました。 サポートベクターマシン(SVC)にて分類 from sklearn import svm from sklearn.metrics import accuracy_score svc = svm.SVC(random_state=42,C=1, kernel='rbf', gamma=0.0000001) svc.fit(X_train, y_train) y_pred_svc = svc.predict(X_test) print ("Accuracy: %.2f"%accuracy_score(y_test, y_pred)) print("predicted", y_pred_svc[0:20]) print("actual", y_test[0:20]) 出力 Accuracy: 0.85 predicted [3 1 3 1 1 0 1 3 0 3 0 3 3 2 1 0 0 3 1 1] actual [2 1 3 1 1 0 1 3 0 3 2 3 3 2 2 0 0 3 1 1] さすがサポートベクターマシン、精度は85%を記録しました。 記述量もたったこれだけ、なんて簡単なんでしょう!さらに私の環境では1秒足らずで結果を出力します。 出力のpredictedとactualでは、testデータの最初の20枚の正誤について確認しています。 結果については後でもう少し分析します。 ディープラーニング(CNN不使用)にて分類 from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Activation, Dense, Dropout, Input, BatchNormalization from sklearn.metrics import mean_squared_error import numpy as np import pandas as pd #Yデータをワンホットエンコーディング y_train_onehot = pd.get_dummies(y_train) y_test_onehot = pd.get_dummies(y_test) model = Sequential() model.add(Dense(2048, input_dim=2914)) model.add(Activation("relu")) model.add(Dropout(rate=0.2)) model.add(Dense(64)) model.add(Activation("relu")) model.add(Dropout(rate=0.2)) model.add(Dense(4)) model.add(Activation("softmax")) model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]) history = model.fit(X_train, y_train_onehot, epochs= 50, batch_size= 32, verbose=1, validation_data=(X_test, y_test_onehot) ) #スコア出力 score = model.evaluate(X_test, y_test_onehot, verbose=1) print("evaluate loss: {0[0]}\nevaluate acc: {0[1]}".format(score)) pred = np.argmax(model.predict(X_test), axis=1) print("predicted", pred[0:20]) print("actual", y_test[0:20]) plt.plot(history.history["accuracy"], label="accuracy", ls="-", marker="o") plt.ylabel("accuracy") plt.xlabel("epoch") plt.legend(loc="best") plt.show() 出力 --epocについては省略-- evaluate loss: 1.386299729347229 evaluate acc: 0.25 predicted [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] actual [2 1 3 1 1 0 1 3 0 3 2 3 3 2 2 0 0 3 1 1] サポートベクターマシンでできるならCNNを使用しないディープラーニングでもできるでしょ!?との思いからやってみたものの、結果は散々でした。何度パラメータを変えても精度は25%になります。predictedを見ると分かるように、適当に結果を出力しているだけでした。どこかの設定ミスったのかな・・・。 ちなみに、ディープラーニングではYのラベルはワンホットエンコーディングをしなければエラーとなりました。多値分類でディープラーニングの出力層にsoftmaxを使用しているので、ワンホットエンコーディングが必要なようです。 畳み込みニューラルネットワーク(CNN) from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation from tensorflow.keras.layers import Conv2D, MaxPooling2D from tensorflow.keras.models import Sequential, load_model from tensorflow.keras.utils import to_categorical, plot_model from tensorflow.keras.callbacks import EarlyStopping import numpy as np import matplotlib.pyplot as plt #Yデータをワンホットエンコーディング y_train_onehot = pd.get_dummies(y_train) y_test_onehot = pd.get_dummies(y_test) #Xデータを画像形式に変換 #Convレイヤーは4次元配列をとる(バッチサイズx縦x横xチャンネル数) X_train = X_train.reshape(-1, 62, 47 , 1) X_test = X_test.reshape(-1, 62, 47 , 1) # モデルの定義 model = Sequential() model.add(Conv2D(filters=32, kernel_size=(3, 3), input_shape=(62,47,1))) model.add(Activation('relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.4)) model.add(Conv2D(filters=32, kernel_size=(3, 3))) model.add(Activation('relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.4)) model.add(Flatten()) model.add(Dense(128)) model.add(Activation('relu')) model.add(Dropout(0.5)) model.add(Dense(4)) model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) #early stoppingも導入する earlyStopping = EarlyStopping(monitor = 'val_loss', patience = 50, mode ="auto") history = model.fit(X_train, y_train_onehot, batch_size=64, epochs=1000, verbose=1, validation_data=(X_test, y_test_onehot)) #スコア出力 score = model.evaluate(X_test, y_test_onehot, verbose=1) print("evaluate loss: {0[0]}\nevaluate acc: {0[1]}".format(score)) pred_cnn = np.argmax(model.predict(X_test), axis=1) print("predicted", pred_cnn[0:20]) print("actual", y_test[0:20]) plt.plot(history.history["accuracy"], label="accuracy", ls="-", marker="o") plt.ylabel("accuracy") plt.xlabel("epoch") plt.legend(loc="best") plt.show() 出力 --999/1000までのepocは省略-- Epoch 1000/1000 6/6 [==============================] - 1s 112ms/step - loss: 0.0164 - accuracy: 0.9881 - val_loss: 1.3945 - val_accuracy: 0.7917 5/5 [==============================] - 0s 18ms/step - loss: 1.3945 - accuracy: 0.7917 evaluate loss: 1.3944944143295288 evaluate acc: 0.7916666865348816 predicted [3 1 3 1 1 0 1 3 0 3 2 0 3 1 3 3 0 3 1 1] actual [2 1 3 1 1 0 1 3 0 3 2 3 3 2 2 0 0 3 1 1] いや、さすがCNN。これぞディープラーニング。epoc毎の精度の上昇は見ていて気落ちがいいです。しかしながら、trainデータには精度98〜100%が出たのですが、testデータでの精度は79%でした。完全に過学習してます。 結果について データ分析的なこともしたいので、サポートベクターマシン(SVC)と畳み込みニューラルネットワーク(CNN)がどのような誤分類をしたのか確認しました。 分析用のデータフレーム作成 result_svc = [] for i in range(len(y_test)): if y_pred_svc[i] != y_test[i]: result_svc.append([i, y_test[i], y_pred_svc[i]]) result_svc = pd.DataFrame(result_svc, columns = ["index number","actual","predict" ]) result_cnn = [] for i in range(len(y_test)): if pred_cnn[i] != y_test[i]: result_cnn.append([i, y_test[i], pred_cnn[i]]) result_cnn = pd.DataFrame(result_cnn, columns = ["index number","actual","predict" ]) このセルにてSVCとCNNのテストデータの予想が外れた結果のみを抽出して、result_svcとresult_cnnというpandasのデータフレームを作成しました。 サポートベクターマシン(SVC)の結果のまとめ result_svc.groupby(["actual","predict"]).count().set_axis(["回数"], axis='columns') 出力 "actual"と"predict"列でgroupbyを使用して、それぞれの回数を見てみました。例えば、表の一番上の結果は「0」の人を「1」の人に誤分類する回数は1回ということを表しています。ちなみにテストデータは各人物84枚です。 SVCの結果では、「2」の人を予想することが難しく、「2」の人を「0」の人に間違えることが多い(5回)ことがわかりました。 ちなみに、0〜3の人はそれぞれ['Colin Powell' 'Donald Rumsfeld' 'George W Bush' 'Tony Blair']です。 畳み込みニューラルネットワーク(CNN)の結果のまとめ result_cnn.groupby(["actual","predict"]).count().set_axis(["回数"], axis='columns') 出力 SVCと同様に"actual"と"predict"列でgroupbyを使用して、それぞれの回数を見てみました。CNNも「2」の人を分類するのが難しく、CNNの場合は「1」の人に誤分類することが多い(6回)ようです。 最後に CNNの精度を上げる手法として画像の水増しする方法もあるのですが、そこまでは行いませんでした。 ハイパーパラメータの調整と合わせて今後時間があったらやってみようと思います。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

~ ワーシャルフロイド法 ~ チートシート

目次 ワーシャルフロイド法とは 実装 もしかして:ダイクストラ法 はじめに チートシートの扱いついてはここを読んでください ワーシャルフロイド法とは わかりやすいサイト 全ての頂点間に対し、別の頂点を経由するとコストが低くならないかを調べることで、全頂点間の最小移動コストを求めるアルゴリズム 全頂点が始点となる場合に有効 計算量が大きいので、始点が1つに定まっている場合はダイクストラ法を使う コストが負の経路があっても使うことができる コストが負の閉路に注意 実装 問題(Atcoderだと簡単な問題がなかったので、AIZU ONLINE JUDGEより) Floyd–Warshall_algorithm.py N,M = map(int,input().split()) #頂点の数、辺の数 mat = [[99999999999999999999] * (N+1) for i in range(N+1)] #最短距離の候補を格納する配列(mat[A][B]:AからBの最短経路) for i in range(N+1): mat[i][i] = 0 for i in range(M): A,B,C = map(int,input().split()) #辺の始点、辺の終点、移動に必要なコスト mat[A][B] = C def warshall_floyd(): flag = 1 #収束判定用の変数 for i in range(N): for j in range(N): for k in range(N): #for i in range(1,N+1): #頂点の番号が1,2,3...と振られている場合(Atcoder用) #for j in range(1,N+1): #頂点の番号が1,2,3...と振られている場合(Atcoder用) #for k in range(1,N+1): #頂点の番号が1,2,3...と振られている場合(Atcoder用) if mat[i][j] > mat[i][k] + mat[k][j] and mat[i][k] != 99999999999999999999 and mat[k][j] != 99999999999999999999: mat[i][j] = mat[i][k] + mat[k][j] flag = 0 if mat[i][i] < 0: #コストが負の閉路がある場合の処理 print("NEGATIVE CYCLE") exit() return flag while True: check = warshall_floyd() if check == 1: break ans = [[0] * (N) for i in range(N)] #答えを格納するための配列 for i in range(N): for j in range(N): ans[i][j] = mat[i][j] #ans[i][j] = mat[i+1][j+1] #頂点の番号が1,2,3...と振られている場合(Atcoder用) if ans[i][j] == 99999999999999999999: ans[i][j] = "INF" [print(*ans[i]) for i in range(len(ans))] Atcoderと異なりインデックスが0始まりなので注意(Pythonと同じだからこっちの方がありがたいけど) コストが負の閉路(始点と終点が同じ頂点であるループ)があると、そこを永遠に周回することでコスト-∞を達成してしまうので注意 ライブラリshortest_path()を使うという手もあるけど、Atcoderだとなんかめんどくさいらしい
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

なろうの更新通知をLINEで受け取る。

背景とやったこと 更新頻度の低い小説をいちいち目視で更新確認するのが面倒だったので、最新話の更新を検知して通知をLINEで受け取れるツールを作成した。 プログラム実行環境 macOS Big Sur バージョン1.5 使用した技術および言語 LINE Notify なろう小説API Python3.9.1 crontab(プログラムの定期実行に使用) コード全体 なろうAPIで最新話数の取得を行い、更新されていた場合LINE Notifyで更新通知を行う。 syosetu_notify.py import requests from bs4 import BeautifulSoup import re def main(): #通知用LINE Notifyの設定 notify_url = 'https://notify-api.line.me/api/notify' notify_token = 'ここにトークンを貼り付け' headers = {"Authorization" : "Bearer "+ notify_token} #前回のプログラム実行時に保存された最新話の確認 f = open('/path/to/current_chapter_syosetu.txt', 'r') previous_chapter_number = f.readline() f.close() #小説の最新話の話数を取得する res = requests.get('https://api.syosetu.com/novelapi/api/?of=ga&ncode=小説コード') soup = BeautifulSoup(res.text, 'html.parser') text = soup.get_text() current_chapter = re.search('general_all_no: [0-9]+', text) current_chapter_text = current_chapter.group() current_chapter_num_search = re.search('[0-9]+', current_chapter_text) current_chapter_number = current_chapter_num_search.group() #保存された話数とサイトから取得した話数が異なっていれば更新されたと判断する if current_chapter_number != previous_chapter_number: notify = '最新話'+current_chapter_number+'が更新されました!' payload = {"message" : notify} requests.post(notify_url, headers = headers, params = payload) #Line通知時に、トーク内に小説のリンクを貼る(任意) syosetu_link = 'https://ncode.syosetu.com/小説コード/' payload = {"message" : syosetu_link} requests.post(notify_url, headers = headers, params = payload) #保存する話数を更新 f = open('/path/to/current_chapter_syosetu.txt','w') f.write(current_chapter_number) f.close() if __name__ == '__main__' : main() LINE Notifyの設定部分 #通知用LINE Notifyの設定 notify_url = 'https://notify-api.line.me/api/notify' notify_token = 'ここにトークンを貼り付け' headers = {"Authorization" : "Bearer "+ notify_token} トークンの入手 LINE Notify にログインして、マイページに移動すると画面下部にトークンの発行できるフォームがある。トークン発行のボタンを押し、LINE Notifyと連携するトークを選択する(選択したトークに通知が送信される)。 その後発行されたトークンを notify_token の部分に貼り付ける。 現在保存されている最新話の確認 #前回のプログラム実行時に保存された最新話の確認 f = open('/path/to/current_chapter_syosetu.txt', 'r') previous_chapter_number = f.readline() f.close() current_chapter_syosetu.txt に更新前の最新話数を保存する。 syosetu_notify.py の初回実行時に current_chapter_syosetu.txt が空のままだと話数の比較ができないため予め話数を記入しておく。2度目以降の実行では自動で更新されるので触らなくて良い。 previous_chapter_number に current_chapter_syosetu.txt から読み込んだ話数が格納される。 最新話数の取得 #小説の最新話の話数を取得する res = requests.get('https://api.syosetu.com/novelapi/api/?of=ga&ncode=小説コード') soup = BeautifulSoup(res.text, 'html.parser') text = soup.get_text() current_chapter = re.search('general_all_no: [0-9]+', text) current_chapter_text = current_chapter.group() current_chapter_num_search = re.search('[0-9]+', current_chapter_text) current_chapter_number = current_chapter_num_search.group() なろう小説API 詳しい案内はこちら。 今回は特定の小説に対して、APIを用いて最新話の取得を行なっている。 https://api.syosetu.com/novelapi/api/?of=ga&ncode=小説コード について of パラメータの ga で話の総数、ncode パラメータで小説ごとに割り振られたIDを指定する。 小説のIDは、なろうのサイトで該当する作品の小説情報を参照する。 current_chapter_number にAPIで取得した最新話数(正確には総話数)が格納される。 話数の比較 #保存された話数とサイトから取得した話数が異なっていれば更新されたと判断する if current_chapter_number != previous_chapter_number: notify = '最新話'+current_chapter_number+'が更新されました!' payload = {"message" : notify} requests.post(notify_url, headers = headers, params = payload) #LINE通知時に、トーク内に小説のリンクを貼る(任意) syosetu_link = 'https://ncode.syosetu.com/小説コード/' payload = {"message" : syosetu_link} requests.post(notify_url, headers = headers, params = payload) #保存する話数を更新 f = open('/path/to/current_chapter_syosetu.txt','w') f.write(current_chapter_number) f.close() 前述の previous_chapter_number と current_chapter_number を比較して、両者の値が異なっていれば小説が更新されたとみなしてLINEで通知する。 その後更新された話数を current_chapter_syosetu.txt に上書きする。 notify や syosetu_link の部分はLINEトーク中にメッセージとして表示される部分であり適宜いじって問題ない。 crontab 以上のプログラムをcrontabで自動実行する。なろうAPIはこの記事の執筆時点(2021年 8月1日)でアクセス数に特に制限はないが、プログラムを走らせるたびにサイトへアクセスが発生するのは確かなので、実行間隔はほどほどに(自分の場合は1日1回ペース)。 ふりかえり 作った後で調べてみたらメールを介して通知を受け取れるサービスは既に存在するようだった。まあLINEで受け取れる方が今風かなということでひとつ。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

TabNetを使ってみた

製造業出身のデータサイエンティストがお送りする記事 今回はTabNetを使ってみました。 TabNetとは TabNetは、ツリーベースモデルとディープニューラルネットワークの利点を持ち合わせた高パフォーマンスなモデルだそうです。 細かい部分は論文を参照して頂けますと幸いです。 TabNetの実装 今回もUCI Machine Learning Repositoryで公開されているボストン住宅の価格データを用いて予測モデルを構築します。 # ライブラリーのインポート import torch from torch import nn from torch.utils.data import DataLoader, Dataset import torch.optim as optim import torch.nn.functional as F from torch.optim.lr_scheduler import ReduceLROnPlateau from sklearn.model_selection import StratifiedKFold from pytorch_tabnet.tab_model import TabNetRegressor import os import random import pandas as pd import numpy as np import seaborn as sns import matplotlib.pyplot as plt %matplotlib inline # ボストンの住宅価格データ from sklearn.datasets import load_boston # 前処理 from sklearn.preprocessing import StandardScaler from sklearn.model_selection import train_test_split # 評価指標 from sklearn.metrics import r2_score from sklearn.metrics import mean_absolute_error from sklearn.metrics import mean_squared_error def seed_everything(seed_value): random.seed(seed_value) np.random.seed(seed_value) torch.manual_seed(seed_value) os.environ["PYTHONHASHSEED"] = str(seed_value) if torch.cuda.is_available(): torch.cuda.manual_seed(seed_value) torch.cuda.manual_seed_all(seed_value) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False seed_everything(10) # データセットの読込み boston = load_boston() # 説明変数の格納 df = pd.DataFrame(boston.data, columns = boston.feature_names) # 目的変数の追加 df['MEDV'] = boston.target # データの中身を確認 df.head() 次にデータセットを分割します(train, valid, test)。 # ランダムシード値 RANDOM_STATE = 10 # 学習データと評価データの割合 TEST_SIZE = 0.2 # 学習データと評価データを作成 x_train, x_test, y_train, y_test = train_test_split( df.iloc[:, 0 : df.shape[1] - 1], df.iloc[:, df.shape[1] - 1], test_size=TEST_SIZE, random_state=RANDOM_STATE, ) # trainのデータセットの2割をモデル学習時のバリデーションデータとして利用する x_train, x_valid, y_train, y_valid = train_test_split( x_train, y_train, test_size=TEST_SIZE, random_state=RANDOM_STATE ) 次にパラメータをセットします。詳細はRepositoryのReadmeに記載されております。 # モデルのパラメータ tabnet_params = dict( n_d=15, n_a=15, n_steps=8, gamma=0.2, seed=10, lambda_sparse=1e-3, optimizer_fn=torch.optim.Adam, optimizer_params=dict(lr=2e-2, weight_decay=1e-5), mask_type="entmax", scheduler_params=dict( max_lr=0.05, steps_per_epoch=int(x_train.shape[0] / 256), epochs=200, is_batch_level=True, ), verbose=5, ) まだ、各パラメータがモデルにどのように影響するのか把握できておりませんので、今後更に使い込んでみようと思います。 次にモデルの学習を行います。 # model model = TabNetRegressor(**tabnet_params) model.fit( X_train=x_train.values, y_train=y_train.values.reshape(-1, 1), eval_set=[(x_valid.values, y_valid.values.reshape(-1, 1))], eval_metric=["mae"], max_epochs=200, patience=30, batch_size=256, virtual_batch_size=128, num_workers=2, drop_last=False, loss_fn=torch.nn.functional.l1_loss, ) TabNetでは、変数重要度も算出できます。細かい算出ロジックはまだ理解できておりません。 # Feature Importance feat_imp = pd.DataFrame(model.feature_importances_, index=boston.feature_names) feature_importance = feat_imp.copy() feature_importance["imp_mean"] = feature_importance.mean(axis=1) feature_importance = feature_importance.sort_values("imp_mean") plt.tick_params(labelsize=18) plt.barh(feature_importance.index.values, feature_importance["imp_mean"]) plt.title("feature_importance", fontsize=18) またTabNetでは、マスクという横軸を使用した特徴量、縦軸にデータを表し、重要な特徴量を濃淡で表す機能もあります。 # Mask(Local interpretability) explain_matrix, masks = model.explain(x_test.values) fig, axs = plt.subplots(1, 3, figsize=(10, 7)) for i in range(3): axs[i].imshow(masks[i][:25]) axs[i].set_title(f"mask {i}") 最後に予測を行います。 # TabNet推論 y_pred = model.predict(x_test.values) # 評価 def calculate_scores(true, pred): """全ての評価指標を計算する Parameters ---------- true (np.array) : 実測値 pred (np.array) : 予測値 Returns ------- scores (pd.DataFrame) : 各評価指標を纏めた結果 """ scores = {} scores = pd.DataFrame( { "R2": r2_score(true, pred), "MAE": mean_absolute_error(true, pred), "MSE": mean_squared_error(true, pred), "RMSE": np.sqrt(mean_squared_error(true, pred)), }, index=["scores"], ) return scores scores = calculate_scores(y_test, y_pred) print(scores) 出力結果は下記のようになります。 R2 MAE MSE RMSE scores 0.90156 2.466226 10.294959 3.208576 さいごに 最後まで読んで頂き、ありがとうございました。 今回はTabNetを使ってみました。 訂正要望がありましたら、ご連絡頂けますと幸いです。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Databricks(Spark)にてPythonによりデータフレーム・テーブルのサイズを確認する方法

概要 Databricks(Spark)にてPythonによりデータフレーム・テーブルのサイズを確認する方法を紹介します。 DatabricksのドキュメントにてScalaでテーブルサイズを確認する方法が紹介されており、Python(PySpark)に置き換えました。 引用元:テーブルのサイズの検索-Azure Databricks - Workspace | Microsoft Docs Delta lake形式とその他の形式で取得方法が異なり、Delta Lake形式ではdescribe detailを利用するように変更しております。おそらくOSSのdelta lakeを利用している場合にもデータが取得できるはずです。 Delta Lake形式におけるデータフレーム・テーブルのサイズを確認する方法 Scalaでの実行例 %scala import com.databricks.sql.transaction.tahoe._ val table_size = DeltaLog.forTable(spark, "dbfs:/tmp/qiita/flights/summary-data/delta").snapshot.sizeInBytes println(s"Total file size (bytes): ${table_size}") テーブル名を指定する方法 # テーブル名を指定する方法 table_name = 'flights_summary_data' table_size = spark.sql(f'describe detail {table_name}').select('sizeInBytes').first()[0] print(f"Total file size (bytes): {table_size}") テーブルのファイルパスを指定する方法 # テーブルのファイルパスを指定する方法 table_location = '"dbfs:/tmp/qiita/flights/summary-data/delta"' table_size = spark.sql(f'describe detail {table_location}').select('sizeInBytes').first()[0] print(f"Total file size (bytes): {table_size}") Delta Lake形式以外(Parquet・CSV・Json等)におけるデータフレーム・テーブルのサイズを確認する方法 Scalaでの実行例 %scala spark.read.table("flights_summary_data_parquet").queryExecution.analyzed.stats.sizeInBytes テーブル名を指定する方法 # テーブル名を指定する方法 table_name = 'flights_summary_data_parquet' table_size = spark.read.table(table_name)._jdf.queryExecution().analyzed().stats().sizeInBytes() print(f"Total file size (bytes): {table_size}") テーブルのファイルパスを指定する方法 # テーブルのファイルパスを指定する方法 table_location = "/databricks-datasets/learning-spark-v2/flights/summary-data/parquet/*" df = (spark .read .format("parquet") .option("inferSchema", "True") .load(table_location) ) table_size = df._jdf.queryExecution().analyzed().stats().sizeInBytes() print(f"Total file size (bytes): {table_size}") その他 CSV形式のデータフレームのサイズを確認する方法 # CSVファイルのディレクトリと合計サイズを表示 file_list =dbutils.fs.ls("/databricks-datasets/learning-spark-v2/flights/summary-data/csv") display(file_list) spark.createDataFrame(file_list).groupBy().sum('size').display() # テーブルのファイルパスを指定する方法 table_location = "/databricks-datasets/learning-spark-v2/flights/summary-data/csv/*" schema = """ DEST_COUNTRY_NAME STRING ,ORIGIN_COUNTRY_NAME STRING ,count INT """ df = (spark .read .format("csv") .schema(schema) .option("header", "true") .option("inferSchema", "False") .load(table_location) ) table_size = df._jdf.queryExecution().analyzed().stats().sizeInBytes() print(f"Total file size (bytes): {table_size}")
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Databricks(Spark)にてPythonによりテーブルのサイズを確認する方法

概要 Databricks(Spark)にてPythonによりテーブルのサイズを確認する方法を紹介します。 DatabricksのドキュメントにてScalaでテーブルサイズを確認する方法が紹介されており、Python(PySpark)に置き換えました。 引用元:テーブルのサイズの検索-Azure Databricks - Workspace | Microsoft Docs Delta lake形式とその他の形式で取得方法が異なり、Delta Lake形式ではdescribe detailを利用するように変更しております。おそらくOSSのdelta lakeを利用している場合にもデータが取得できるはずです。 Delta Lake形式におけるデータフレーム・テーブルのサイズを確認する方法 Scalaでの実行例 %scala import com.databricks.sql.transaction.tahoe._ val table_size = DeltaLog.forTable(spark, "dbfs:/tmp/qiita/flights/summary-data/delta").snapshot.sizeInBytes println(s"Total file size (bytes): ${table_size}") テーブル名を指定する方法 # テーブル名を指定する方法 table_name = 'flights_summary_data' table_size = spark.sql(f'describe detail {table_name}').select('sizeInBytes').first()[0] print(f"Total file size (bytes): {table_size}") テーブルのファイルパスを指定する方法 # テーブルのファイルパスを指定する方法 table_location = '"dbfs:/tmp/qiita/flights/summary-data/delta"' table_size = spark.sql(f'describe detail {table_location}').select('sizeInBytes').first()[0] print(f"Total file size (bytes): {table_size}") Delta Lake形式以外(Parquet・CSV・Json等)におけるデータフレーム・テーブルのサイズを確認する方法 Scalaでの実行例 %scala spark.read.table("flights_summary_data_parquet").queryExecution.analyzed.stats.sizeInBytes テーブル名を指定する方法 # テーブル名を指定する方法 table_name = 'flights_summary_data_parquet' table_size = spark.read.table(table_name)._jdf.queryExecution().analyzed().stats().sizeInBytes() print(f"Total file size (bytes): {table_size}") テーブルのファイルパスを指定する方法 # テーブルのファイルパスを指定する方法 table_location = "/databricks-datasets/learning-spark-v2/flights/summary-data/parquet/*" df = (spark .read .format("parquet") .option("inferSchema", "True") .load(table_location) ) table_size = df._jdf.queryExecution().analyzed().stats().sizeInBytes() print(f"Total file size (bytes): {table_size}") その他 CSV形式のデータフレームのサイズを確認する方法 # CSVファイルのディレクトリと合計サイズを表示 file_list =dbutils.fs.ls("/databricks-datasets/learning-spark-v2/flights/summary-data/csv") display(file_list) spark.createDataFrame(file_list).groupBy().sum('size').display() # テーブルのファイルパスを指定する方法 table_location = "/databricks-datasets/learning-spark-v2/flights/summary-data/csv/*" schema = """ DEST_COUNTRY_NAME STRING ,ORIGIN_COUNTRY_NAME STRING ,count INT """ df = (spark .read .format("csv") .schema(schema) .option("header", "true") .option("inferSchema", "False") .load(table_location) ) table_size = df._jdf.queryExecution().analyzed().stats().sizeInBytes() print(f"Total file size (bytes): {table_size}")
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

テスト技法ツールGIHOZのペアワイズテストの組み合わせ網羅度合いを可視化するツール

はじめに テスト技法ツールGIHOZのペアワイズテストで作成した組み合わせ結果を入力として、組み合わせの網羅度合いを可視化するツールを作ってみました。この記事ではそのツールを紹介します。 ペアワイズテストは入力パラメータの組み合わせを作るテスト技法で、2つのパラメータ間の値の組み合わせを網羅できます。オールペア法とも言いますが、この記事ではGIHOZの表記にならってペアワイズテストと記載します。 ツールの概要 ペアワイズテストの組み合わせ生成結果から、2パラメータ間の組み合わせの網羅度合いを可視化するためのツールです。縦軸と横軸にパラメータと値を並べたマトリクスを生成し、マトリクスの交点にあたる組み合わせが何回登場しているかを可視化することができます。具体的には、以下のような表を生成します。この表を何と呼ぶのが正しいのか分かりませんが、「総当たり表」と呼ばれることもあるようです。 GIHOZの組み合わせ生成結果を信じるのであれば総当たり表は特に必要ないですが、第三者などから「本当に2パラメータ間の組み合わせの網羅ができているの?」と問われたときや、制約をかけて組み合わせ生成したときに思った通りの組み合わせ結果になっているか不安なときなどに、総当たり表を見ることで、組み合わせの網羅度合いを確認できます。 ソースコード・動作確認環境 ソースコードはGitHubのリポジトリで公開しています。 以下の環境で動作することを確認しました。 OS: Windows10 言語: Python 3.9.5 使い方 まずPython3をインストールし、次にpandasというライブラリをインストールします。pandasはpipコマンドが使える方は以下のコマンドでインストールできます。Anacondaなどをご利用の方はご自分でコマンドを調べてください。 $ pip install pandas ここまでで環境の準備は完了です。 GIHOZのペアワイズテスト画面で、適当に組み合わせを生成して、結果をcsvファイルとしてダウンロードします。 上記のGitHubのリポジトリからソースコードをダウンロードして、コマンドプロンプトやPowerShellでソースコードの保存先を開いて、以下のコマンドを打てば実行できます。 $ python check_combinations.py [ペアワイズテストのCSVファイル名] [ペアワイズテストのCSVファイル名]は、GIHOZからダウンロードした直後はPair-wise_generated-results_yyyy-mm-dd.csvといったファイル名になっています。(yyyy-mm-ddの部分はダウンロードした日付です) 実行に成功するとchecked.csvという名前のファイルが生成されます。 checked.csvを開いてそのまま数字を確認しても良いですが、Excelで開いて条件付き書式のカラースケールを指定すると、どの組み合わせがどの程度登場したか、視覚的に確認できます。対角線上のセルは、同じパラメータ同士の組み合わせは存在しないので「0」になっており、カラースケールで赤くなっています。 適当に制約をかけて組み合わせ生成した結果から総当たり表を作成すると、対角線上以外にも、登場しない組み合わせが存在することが分かります。 おわりに GIHOZのペアワイズテストの生成結果から組み合わせの網羅度合いを可視化するツールを作ってみた、という記事でした。ペアワイズテストのツールとして有名な「PICT」で組み合わせを生成した際にも、結果をcsv形式で保存してやれば、このツールで同じように網羅度合いを可視化できます。様々なパラメータを組み合わせてテストを行う際の一助になればと思います。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

インストール無しにWEBブラウザで動作するPython環境"Pyodide"を使ってみた

Pyodideとは WebAssembly技術を用いてWEBブラウザ上でPython3を動かすことができる素晴らしいソフトウェアです。WEBブラウザでPyodideを簡単に体験できるサイトが用意されており、これは面白そうだと思いました。 WEBエディタAceとの組みあわせ 本来のPyodideの目的はWEBブラウザ上でのPythonでの科学計算だと思いますので、エディタとしてはIodideのようにグラフ表示の要求優先度が高そうです。 とはいえ、コンソール上で単純にPythonが動くだけでも応用は利きそうだと思いましたので、今回はPythonコードを記述するエディタとしてAceを利用します。 simple版コードは下記のようになります。Aceエディタの領域とPyodide実行ボタンを縦に並べて、JavaScript内で初期化、実行するだけです。実行ログはF12キーで出るブラウザのコンソールを見てください。下記コードを index.htmlにコピーして、ブラウザで開くとNumpyのサンプルが動きます。 <!DOCTYPE html> <html lang="en-us"> <head> <meta charset="utf-8"> <meta http-equiv="Content-Type" content="text/html; charset=utf-8"> <title>pyodide test (simple)</title> <style type="text/css" media="screen"> #editor { width:100%; height:500px; margin-top:auto; margin-bottom:auto; } </style> <script src="https://cdn.jsdelivr.net/pyodide/v0.17.0/full/pyodide.js"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/ace/1.4.12/ace.js" type="text/javascript" charset="utf-8"></script> </head> <body style="text-align: center" > <div id="layout"> <div id="editor">import numpy as np a=np.array([1,2,3]) b=np.array([4,5,6]) print(a+b) </div> <br/> <input class="button" id="runBtn" style="background-color: #4CAF50;" type="button" value="Run Script" onclick="RunCode();" /> </div> <script> // init Ace Editor var editor = ace.edit("editor"); editor.setTheme("ace/theme/twilight"); editor.session.setMode("ace/mode/python"); // init Pyodide async function main(){ await loadPyodide({ indexURL : 'https://cdn.jsdelivr.net/pyodide/v0.17.0/full/' }); } let pyodideReadyPromise = main(); async function RunCode() { await pyodideReadyPromise; try { code = editor.getValue(); let output = await pyodide.runPythonAsync(code); } catch(err) { console.log(err); } } </script> </body> </html> 実行画面は下記です。numpyをimportするのに少し時間がかかりますが、ちゃんと答えが出ていますね。このほか、色々なライブラリをサポートしているとは思いますが、何がどこまで動くかは私もよくわかってませんのでぜひご自分でも試してみてください。 ファイルのロードセーブ Aceを使う場合に、Ace自体はファイルのロードセーブがサポートされていないので色々調べないといけません。ロードは簡単なのですが、セーブはできないためブラウザのダウンロード機能を利用する形で実装したのが下記のコードになります。レイアウトもsimple版よりは整えてます。 <!DOCTYPE html> <html lang="en-us"> <head> <meta charset="utf-8"> <meta http-equiv="Content-Type" content="text/html; charset=utf-8"> <title>pyodide test</title> <style type="text/css" media="screen"> #editor { width:100%; height:500px; margin-top:auto; margin-bottom:auto; } #editorGroup { position:absolute; margin-left:810px; margin-right:0; margin-top:auto; margin-bottom:auto; height:600px; } .button { width:120px; height:60px; background-color: #AAAAAA; border: none; color: white; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; border-radius: 4px; } #layout { float: left; width:800px; background-color: #FFFFFF; } .clear { clear:both; } table { border-collapse: collapse; /* セルの境界線を共有 */ } td { border: 1px solid black; /* 表の罫線(=セルの枠線) */ padding: 0.5em 1em; /* セル内側の余白量 */ } body { background: #eeeeee; font-family: Meiryo; } h1 { border: none; text-align: center; text-decoration: none; margin: 4px 2px; } </style> <script src="https://cdn.jsdelivr.net/pyodide/v0.17.0/full/pyodide.js"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/ace/1.4.12/ace.js" type="text/javascript" charset="utf-8"></script> </head> <body style="text-align: center" > <div id="layout"> <h1>Python Editor</h1> <div id="editor">import numpy as np a=np.array([1,2,3]) b=np.array([4,5,6]) print(a+b) </div> <br/> <table id="control_panel" style="width:800px" > <tr> <td rowspan="2"> <input class="button" id="runBtn" style="background-color: #4CAF50;" type="button" value="Run Script" onclick="RunCode();" /> </td> <td>Load File: <input type="file" id="loadBtn" onchange="Load(this.files);this.value=null;return false;" value="Open" /> </td> </tr> <tr> <td> <input id="inputFileNameToSaveAs"></input> <input class="button" id="saveasBtn" style="background-color: #728edb;height:30px;" type="button" value="Download" onclick="Save();" /> </td> </tr> </table> </div> <script> // init Ace Editor var editor = ace.edit("editor"); editor.setTheme("ace/theme/twilight"); editor.session.setMode("ace/mode/python"); // init Pyodide async function main(){ await loadPyodide({ indexURL : 'https://cdn.jsdelivr.net/pyodide/v0.17.0/full/' }); } let pyodideReadyPromise = main(); async function RunCode() { await pyodideReadyPromise; try { code = editor.getValue(); let output = await pyodide.runPythonAsync(code); } catch(err) { console.log(err); } } function Load(files) { var file = files[0] console.log("Load:" + file); if (!file) return; reader = new FileReader(); reader.onload = function() { editor.session.setValue(reader.result) } reader.readAsText(file) } function Save() { var textToSaveAsBlob = new Blob([editor.getValue()], {type:"text/plain"}); var textToSaveAsURL = window.URL.createObjectURL(textToSaveAsBlob); var fileNameToSaveAs = document.getElementById("inputFileNameToSaveAs").value; var downloadLink = document.createElement("a"); downloadLink.download = fileNameToSaveAs; downloadLink.innerHTML = "Download File"; downloadLink.href = textToSaveAsURL; downloadLink.onclick = function() { document.body.removeChild(event.target); } downloadLink.style.display = "none"; document.body.appendChild(downloadLink); downloadLink.click(); } </script> </body> </html> 実行画面はこんな感じです。 (追記) MatplotLibの図をブラウザに出力する方法 使う予定はないもののMatplotlibの絵を出せないのは気持ち悪いので、StackOverflowの記事を参考にしてsimple版コードを改造して出せるようにしてみました。ざっくり説明すると、id="fig"に対してPython側でPNG画像をBase64に変換して書き込んでいるようです。from js import document でJavaScriptと連携できることが分かって満足。 ついでにStatusも出せるようにしてみました。id='status' を状態に応じて変えるようにしているだけですが、状況がわかりやすくなりました。 <!DOCTYPE html> <html lang="en-us"> <head> <meta charset="utf-8"> <meta http-equiv="Content-Type" content="text/html; charset=utf-8"> <title>pyodide test (simple)</title> <style type="text/css" media="screen"> #editor { width:100%; height:500px; margin-top:auto; margin-bottom:auto; } </style> <script src="https://cdn.jsdelivr.net/pyodide/v0.17.0/full/pyodide.js"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/ace/1.4.12/ace.js" type="text/javascript" charset="utf-8"></script> </head> <body style="text-align: center" > <div id="layout"> <div id="editor">from js import document import numpy as np import scipy.stats as stats import matplotlib.pyplot as plt import io, base64 def generate_plot_img(): # get values from inputs mu = 1 sigma = 1 # generate an interval x = np.linspace(mu - 3*sigma, mu + 3*sigma, 100) # calculate PDF for each value in the x given mu and sigma and plot a line plt.plot(x, stats.norm.pdf(x, mu, sigma)) # create buffer for an image buf = io.BytesIO() # copy the plot into the buffer plt.savefig(buf, format='png') buf.seek(0) # encode the image as Base64 string img_str = 'data:image/png;base64,' + base64.b64encode(buf.read()).decode('UTF-8') # show the image img_tag = document.getElementById('fig') img_tag.src = img_str buf.close() generate_plot_img() </div> <br/> <input class="button" id="runBtn" style="background-color: #4CAF50;" type="button" value="Run Script" onclick="RunCode();" /> <br/> Status: <strong id='status'>Initializing...</strong> <br> <img id="fig" /> </div> <script> // init Ace Editor var editor = ace.edit("editor"); editor.setTheme("ace/theme/twilight"); editor.session.setMode("ace/mode/python"); // init Pyodide async function main(){ await loadPyodide({ indexURL : 'https://cdn.jsdelivr.net/pyodide/v0.17.0/full/' }).then(()=>document.getElementById('status').innerHTML='Start'); } let pyodideReadyPromise = main(); async function RunCode() { await pyodideReadyPromise; try { document.getElementById('status').innerHTML='Executing...'; code = editor.getValue(); let output = await pyodide.runPythonAsync(code).then(()=>document.getElementById('status').innerHTML='Done!'); } catch(err) { console.log(err); } } </script> </body> </html> 最後に もうちょっと苦戦するかなと思ったんですが、意外にあっさり動いたので驚きました。最近のWEB系技術の進歩はすさまじいですね。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Progateで作ったWebアプリをDjangoで作ってみる2! Part7 -DetailView編-

目標物の確認 ProgateのNode.jsコースで作ったブログアプリと同じものをDjangoで作ってみます。 Djangoでのアプリ開発の一連の流れを整理するために記していきます。 完成イメージ 記事画面の作成 個々の記事はDetailViewというViewで表示させます。 まずはHTMLファイル(article.html)を作成します。 Part6で作成したbase.htmlを継承して(☆1)、{% block content %} {% endblock content %}の中に個別記事を表示するコードを書いていきます(☆2)。 記事が「全員(object.category == 'all')」の場合、コンテンツが表示されるようにします(☆3)。記事が、「会員限定(object.category == 'limited')」の場合、タイトルは表示されるが、コンテンツはログインされている場合(user.is_authenticated)は表示されるようにします(☆4)。ログインされていない場合は、コンテンツは表示されず、ログイン画面へのリンクが表示されるようします(☆4)。 article.html {% extends 'blog/base.html' %} <!--☆1--> {% load static %} {% block content %} <!--☆2--> <main> <div class="article"> {% if object.category == 'all' %} <!--☆3--> <h1>{{ object.title }}</h1> <p>{{ object.content }}</p> {% endif %} <!--☆3--> {% if object.category == 'limited' %} <!--☆4--> <i>会員限定</i> <h1>{{ object.title }}</h1> <!--☆4--> {% if user.is_authenticated %} <!--☆4--> <p>{{ object.content }}</p> <!--☆4--> {% else %} <!--☆4--> <div class="article-login"> <p>今すぐログインしよう!</p> <p>記事の続きは<br>ログインすると読むことができます</p> <img src="{% static 'blog/login.svg' %}"> <a class="btn" href="{% url 'blog:login' %}">ログイン</a> </div> {% endif %} <!--☆4--> {% endif %} <!--☆4--> </div> </main> <footer> <a class="btn sub" href="{% url 'blog:list' %}">一覧にもどる</a> </footer> {% endblock content %} <!--☆2--> views.pyファイルを作成していきます(☆5)。 blogapp/blog/views.py from blog.models import BlogModel from django.shortcuts import render, redirect from django.views.generic import TemplateView, ListView, DetailView # ☆5 from django.contrib.auth.models import User from django.db import IntegrityError from django.contrib.auth import authenticate, login, logout from .models import BlogModel class BlogTop(TemplateView):... def signupview(request):... def loginview(request):... def logoutview(request):... class BlogList(ListView):... class BlogArticle(DetailView): # ☆5 template_name = 'blog/article.html' model = BlogModel urls.pyファイルも作成します(☆6)。 <int:pk>が、テーブルに入っているデータを具体的に指定する上で使われるコードです(☆6)。 blogapp/blog/urls.py from django.urls import path from .views import BlogTop, signupview, loginview, logoutview, BlogList, BlogArticle # ☆6 app_name = 'blog' urlpatterns = [ path('', BlogTop.as_view(), name='top'), path('signup/', signupview, name='signup'), path('login/', loginview, name='login'), path('logout/', logoutview, name='logout'), path('list/', BlogList.as_view(), name='list'), path('article/<int:pk>/', BlogArticle.as_view(), name='article') # ☆6 ] ログインの有無でページの表示が変わっています。 成功です!
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

【物体検出】YOLOv5でYouTube動画から物体検出で電車を抽出してみる

はじめに YOLOv5でYouTube動画から直接物体検出する手順について備忘録としてまとめました。 Google Colaboratory上でYOLOv5とtorch hubを使用して、動画から直接物体検出を試していきます。 成果物 以下に示します。 YOLOv5の自作モデルによりYouTube動画から直接電車の抽出を行います。 結果を動画にしましたので以下よりご確認ください。 YOLOv5におけるモデル作成 今回試す電車抽出モデルの作成に関する内容は以下の通りです。 環境構築 まずはGoogle Colaboratoryを起動します。 !git clone https://github.com/ultralytics/yolov5 %cd yolov5/ !pip install -qr requirements.txt これで準備は完了です。 ①学習済みモデルでサクッと試してみる まずは実際にYouTube動画から物体検出を試してみます。 学習済みモデルを使用するので、誰でもすぐに再現することが可能です。 今回使用する動画は以下のものです。 以下の通り実行してみます。 import torch !python detect.py --source 'https://youtu.be/igh_evKnkFkhttps://youtu.be/igh_evKnkFk' しばらくすると結果が以下の場所に保存されます。 /content/yolov5/runs/detect/exp/ YouTube動画から物体検出ができました。 学習済みモデルを使用しているので電車以外にも人間などを検出しています。 ②学習済みモデルでサクッと試してみる 先程の動画に以前作成したモデルで再度検証を行います。 学習モデルはあらかじめ以下の場所にアップしておきます。 /content/yolov5/ 以下の通り実行します。 import torch !python detect.py --source 'https://youtu.be/igh_evKnkFkhttps://youtu.be/igh_evKnkFk' --weights /content/yolov5/best.pt 実際に電車の検出ができました。 結果 ①と②の抽出結果は以下の通り動画にまとめました。 まとめ YOLOv5でYouTube動画から直接物体検出する手順をまとめました。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

【物体検出】Google ColabでYOLOv5を使ってYouTube動画から物体検出で電車を抽出してみる

はじめに YOLOv5でYouTube動画から直接物体検出する手順について備忘録としてまとめました。 Google Colaboratory上でYOLOv5とtorch hubを使用して、YouTube動画から直接物体検出を試していきます。 目標(成果物) 以下に示します。 YOLOv5の自作モデルによりYouTube動画から直接電車の抽出を行います。 結果を動画にしましたので以下よりご確認ください。 YOLOv5におけるモデル作成 今回試す電車抽出モデルの作成は以前作成したものを使用します。詳細は以下の記事を参照してください。 環境構築 まずはGoogle Colaboratoryを起動して、Yolov5の準備を行います。 !git clone https://github.com/ultralytics/yolov5 %cd yolov5/ !pip install -qr requirements.txt これで準備は完了です。 ①学習済みモデルでサクッと試してみる まずは実際にYouTube動画から物体検出を試してみます。 学習済みモデルを使用するので、誰でもすぐに再現することが可能です。 今回使用する動画は以下のものです。 この動画に対して物体検出をしてます。 以下の通り実行します。 import torch !python detect.py --source 'https://youtu.be/igh_evKnkFk' しばらくすると結果が以下の場所に結果が保存されます。 /content/yolov5/runs/detect/exp/ YouTube動画から物体検出ができました。 学習済みモデルを使用しているので電車以外にも人間などを検出しています。 ②自作モデルで試してみる 先程の動画に対して、以前作成したモデルで再度検証を行います。 学習モデルはあらかじめ以下の場所にアップしておきます。 /content/yolov5/ 以下の通り実行します。 引数weightsを与えることで、指定することができます。 import torch !python detect.py --source 'https://youtu.be/igh_evKnkFk' --weights /content/yolov5/best.pt しばらくすると以下の場所に結果が保存されます。 /content/yolov5/runs/detect/exp2/ 実際に電車の検出ができました。 結果 ①と②の抽出結果は以下の通り動画にまとめましたので、興味のある方はご覧下さい。 まとめ YOLOv5でYouTube動画から直接物体検出する手順をまとめました。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Pythonでデータの挙動を見やすくする可視化ツールを作成してみた まとめ編

はじめに 私は今まで、Pythonを使ってデータの可視化ツールを作成してきました。 いくつかはQiitaにも投稿しましたが、記事やコードが分散していて分かりづらいと感じたので、全機能をまとめて「seaborn-analyzer」としてライブラリ化し、概要を本記事にまとめました! 「実際のデータ分析を通じて欲しいと思った機能」を詰め込んだので、実務で役立つ機能がいくつもあるかと思います。 ぜひご活用頂ければと思います 構成 本ツールは以下のクラスからなります。 使用法の詳細は、「使用法リンク」中の各記事、および後述の「各機能の解説」をご参照ください クラス名 パッケージ名 概要 使用法リンク CustomPairPlot custom_pair_plot.py 散布図行列と相関係数行列を同時に表示 リンク hist custom_hist_plot.py ヒストグラムと各種分布のフィッティング リンク classplot custom_scatter_plot.py 分類境界およびクラス確率の表示 リンク regplot custom_scatter_plot.py 相関・回帰分析の散布図・ヒートマップ表示 リンク コードはGitHubにもアップロードしております おすすめ機能 個人的に使用頻度の高い機能を紹介します おすすめ1:CustomPairPlot.pairanalyzer 相関係数と散布図行列を一括表示します。 分析の初期段階でデータを一括で可視化したいときにオススメです。 Rのggplot2ではほぼ同様の図が出力可能ですが、なぜかPythonには同様のツールがなかったので、作成しました。 散布図では表示が重なり見辛い離散変数は、自動で箱ひげ図とバブルチャートに変更する機能も追加しています。 from seaborn_analyzer import CustomPairPlot import seaborn as sns titanic = sns.load_dataset("titanic") cp = CustomPairPlot() cp.pairanalyzer(titanic, hue='survived') おすすめ2:hist.plot_normality 正規性検定の結果をヒストグラムとともに表示します 正規分布かどうかの判断は曖昧な判定となりがちですが、 定量的な裏付けを付けたい時にオススメです from seaborn_analyzer import hist from sklearn.datasets import load_boston import pandas as pd df = pd.DataFrame(load_boston().data, columns= load_boston().feature_names) hist.plot_normality(df, x='LSTAT', norm_hist=False, rounddigit=5) おすすめ3:regplot.linearplot 相関係数、P値、回帰式を同時算出し、散布図と一緒に表示します。。 頑張ればExcelでも算出できる簡単な指標ですが、 全てを同時算出してくれるツールは意外とありません。 多くの情報を1枚の図に集約できるので、ループ実行して多くの変数の相関関係を定量的かつ効率的に見たい時にオススメです from seaborn_analyzer import regplot import seaborn as sns iris = sns.load_dataset("iris") regplot.linear_plot(x='petal_length', y='sepal_length', data=iris) おすすめ4:regplot.regression_heat_plot 説明変数が2次元以上の回帰モデルは可視化が難しいですが、 このメソッドを使えば4次元までならそれなりに直感的な可視化が可能となります。 2次元説明変数での可視化例 import pandas as pd from sklearn.linear_model import LinearRegression from seaborn_analyzer import regplot df_temp = pd.read_csv(f'./sample_data/temp_pressure.csv') regplot.regression_heat_plot(LinearRegression(), x=['altitude', 'latitude'], y='temperature', data=df_temp) 4次元説明変数での可視化例 import pandas as pd from sklearn.linear_model import LinearRegression from seaborn_analyzer import regplot df = pd.read_csv(f'./sample_data/osaka_metropolis_english.csv') regplot.regression_heat_plot(LinearRegression(), x=['2_between_30to60', '3_male_ratio', '5_household_member', 'latitude'], y='approval_rate', data=df, pair_sigmarange = 0.5, rounddigit_x1=3, rounddigit_x2=3) また、クロスバリデーションの可視化や誤差上位の表示など、モデルの傾向把握に便利な機能を多数実装しています。詳細はこちらの記事を参照ください クロスバリデーション&誤差上位を表示 import pandas as pd from sklearn.linear_model import LinearRegression from seaborn_analyzer import regplot df_temp = pd.read_csv(f'./sample_data/temp_pressure.csv') regplot.regression_heat_plot(LinearRegression(), cv=2, display_cv_indices=[0, 1], rank_number=3, rank_col='city', x=['altitude', 'latitude'], y='temperature', data=df_temp) 必要要件 Python >=3.6 Numpy >=1.20.3 Pandas >=1.2.4 Matplotlib >=3.3.4 Scipy >=1.6.3 Scikit-learn >=0.24.2 インストール方法 pipからインストールできます。 $ pip install seaborn-analyzer 各機能の解説 クラスおよびメソッドごとに機能と引数を解説します CustomPairPlotクラス (custom_pair_plot.py) 散布図行列と相関係数行列を同時に表示します。 1個のクラス「CustomPairPlot」からなります ・CustomPairPlotクラス内のメソッド一覧 メソッド名 機能 pairanalyzer 散布図行列と相関係数行列を同時に表示します pairanalyzerメソッド 実行例 from seaborn_analyzer import CustomPairPlot import seaborn as sns titanic = sns.load_dataset("titanic") cp = CustomPairPlot() cp.pairanalyzer(titanic, hue='survived') 引数一覧 引数名 必須引数orオプション 型 デフォルト値 内容 data 必須 pd.DataFrame - 入力データ hue オプション str None 色分けに指定するカラム名 (Noneなら色分けなし) palette オプション str None hueによる色分け用のカラーパレット vars オプション list[str] None グラフ化するカラム名 (Noneなら全ての数値型&Boolean型の列を使用) lowerkind オプション str 'boxscatter' 左下に表示するグラフ種類 ('boxscatter', 'scatter', or 'reg') diag_kind オプション str 'kde' 対角に表示するグラフ種類 ('kde' or 'hist') markers オプション str or list[str] None hueで色分けしたデータの散布図プロット形状 height オプション float 2.5 グラフ1個の高さ aspect オプション float 1 グラフ1個の縦横比 dropna オプション bool True seaborn.PairGridのdropna引数 lower_kws オプション dict {} seaborn.PairGrid.map_lowerの引数 diag_kws オプション dict {} seaborn.PairGrid.map_diag引数 grid_kws オプション dict {} seaborn.PairGridの上記以外の引数 CustomPairPlotクラス使用法詳細 こちらの記事にまとめました https://qiita.com/c60evaporator/items/20f11b6ee965cec48570 histクラス (custom_hist_plot.py) ヒストグラム表示および各種分布のフィッティングを実行します。 ・histクラス内のメソッド一覧 メソッド名 機能 plot_normality 正規性検定とQQプロット fit_dist 各種分布のフィッティングと、評価指標(RSS, AIC, BIC)の算出 plot_normalityメソッド 実行例 from seaborn_analyzer import hist from sklearn.datasets import load_boston import pandas as pd df = pd.DataFrame(load_boston().data, columns= load_boston().feature_names) hist.plot_normality(df, x='LSTAT', norm_hist=False, rounddigit=5) 引数一覧 引数名 必須引数orオプション 型 デフォルト値 内容 data 必須 pd.DataFrame, pd.Series, or pd.ndarray - 入力データ x オプション str None ヒストグラム作成対象のカラム名 (dataがpd.DataFrameのときは必須) hue オプション str None 色分けに指定するカラム名 (Noneなら色分けなし) binwidth オプション float None ビンの幅 (binsと共存不可) bins オプション int 'auto' ビンの数 (bin_widthと共存不可、'auto'ならスタージェスの公式で自動決定) norm_hist オプション bool False ヒストグラムを面積1となるよう正規化するか? sigmarange オプション float 4 フィッティング線の表示範囲 (標準偏差の何倍まで表示するか指定) linesplit オプション float 200 フィッティング線の分割数 (カクカクしたら増やす) rounddigit オプション int 5 表示指標の小数丸め桁数 hist_kws オプション dict {} matplotlib.axes.Axes.histに渡す引数 subplot_kws オプション dict {} matplotlib.pyplot.subplotsに渡す引数 fit_distメソッド 実行例 from seaborn_analyzer import hist from sklearn.datasets import load_boston import pandas as pd import matplotlib.pyplot as plt from scipy import stats df = pd.DataFrame(load_boston().data, columns= load_boston().feature_names) all_params, all_scores = hist.fit_dist(df, x='LSTAT', dist=['norm', 'gamma', 'lognorm', 'uniform']) df_scores = pd.DataFrame(all_scores).T df_scores 引数一覧 引数名 必須引数orオプション 型 デフォルト値 内容 data 必須 pd.DataFrame, pd.Series, or pd.ndarray - 入力データ x オプション str None ヒストグラム作成対象のカラム名 (dataがpd.DataFrameのときは必須) hue オプション str None 色分けに指定するカラム名 (Noneなら色分けなし) binwidth オプション float None ビンの幅 (binsと共存不可) bins オプション int 'auto' ビンの数 (bin_widthと共存不可、'auto'ならスタージェスの公式で自動決定) norm_hist オプション bool False ヒストグラムを面積1となるよう正規化するか? sigmarange オプション float 4 フィッティング線の表示範囲 (標準偏差の何倍まで表示するか指定) linesplit オプション float 200 フィッティング線の分割数 (カクカクしたら増やす) dist オプション str or list[str] 'norm' 分布の種類 ('norm', 'lognorm', 'gamma', 't', 'expon', 'uniform', 'chi2', 'weibull') ax オプション matplotlib.axes.Axes None 表示対象のax (Noneならmatplotlib.pyplot.plotで1枚ごとにプロット) linecolor オプション str or list[str] 'red' フィッティング線の色指定 (listで複数指定可) floc オプション float None フィッティング時のX方向オフセット (Noneなら指定なし(weibullとexponは0)) hist_kws オプション dict {} matplotlib.axes.Axes.histに渡す引数 histクラス使用法詳細 こちらの記事にまとめました https://qiita.com/c60evaporator/items/fc531aff0cdbafac0f42 classplotクラス 分類の決定境界およびクラス確率の表示を実行します。 Scikit-Learn APIに対応した分類モデル (例: XGBoostパッケージのXGBoostClassifierクラス)が表示対象となります ・classplotクラス内のメソッド一覧 メソッド名 機能 class_separator_plot 決定境界プロット class_proba_plot クラス確率プロット class_separator_plotメソッド 実行例 import seaborn as sns from sklearn.svm import SVC from seaborn_analyzer import classplot iris = sns.load_dataset("iris") model = SVC() classplot.class_separator_plot(model, ['petal_width', 'petal_length'], 'species', iris) 引数一覧 引数名 必須引数orオプション 型 デフォルト値 内容 model 必須 Scikit-learn API - 表示対象の回帰モデル x 必須 list[str] - 説明変数に指定するカラム名 y 必須 str - 目的変数に指定するカラム名 data 必須 pd.DataFrame - 入力データ x_chart オプション   list[str] None 説明変数のうちグラフ表示対象のカラム名 pair_sigmarange オプション float 1.5 グラフ非使用変数の分割範囲 pair_sigmainterval オプション float 0.5 グラフ非使用変数の1枚あたり表示範囲 chart_extendsigma オプション float 0.5 グラフ縦軸横軸の表示拡張範囲 chart_scale オプション int 1 グラフの描画倍率 plot_scatter オプション str 'true' 散布図の描画種類 rounddigit_x3 オプション int 2 グラフ非使用軸の小数丸め桁数 scatter_colors オプション list[str] None クラスごとのプロット色のリスト true_marker オプション str None 正解クラスの散布図プロット形状 false_marker オプション str None 不正解クラスの散布図プロット形状 cv オプション int or sklearn.model _selection.* None クロスバリデーション分割法 (Noneのとき学習データから指標算出、int入力時はkFoldで分割) cv_seed オプション int 42 クロスバリデーションの乱数シード cv_group オプション str None GroupKFold,LeaveOneGroupOutのグルーピング対象カラム名 display_cv_indices オプション int 0 表示対象のクロスバリデーション番号 model_params オプション dict None 分類モデルに渡すパラメータ fit_params オプション dict None 学習時のパラメータ subplot_kws オプション dict None matplotlib.pyplot.subplotsに渡す引数 contourf_kws オプション dict None グラフ表示用のmatplotlib.pyplot.contourfに渡す引数 scatter_kws オプション dict None 散布図用のmatplotlib.pyplot.scatterに渡す引数 class_proba_plotメソッド 実行例 import seaborn as sns from sklearn.svm import SVC from seaborn_analyzer import classplot iris = sns.load_dataset("iris") model = SVC() classplot.class_proba_plot(model, ['petal_width', 'petal_length'], 'species', iris, proba_type='imshow') 引数一覧 引数名 必須引数orオプション 型 デフォルト値 内容 model 必須 Scikit-learn API - 表示対象の回帰モデル x 必須 list[str] - 説明変数に指定するカラム名 y 必須 str - 目的変数に指定するカラム名 data 必須 pd.DataFrame - 入力データ x_chart オプション    list[str] None 説明変数のうちグラフ表示対象のカラム名 pair_sigmarange オプション float 1.5 グラフ非使用変数の分割範囲 pair_sigmainterval オプション float 0.5 グラフ非使用変数の1枚あたり表示範囲 chart_extendsigma オプション float 0.5 グラフ縦軸横軸の表示拡張範囲 chart_scale オプション int 1 グラフの描画倍率 plot_scatter オプション str 'true' 散布図の描画種類 rounddigit_x3 オプション int 2 グラフ非使用軸の小数丸め桁数 scatter_colors オプション list[str] None クラスごとのプロット色のリスト true_marker オプション str None 正解クラスの散布図プロット形状 false_marker オプション str None 不正解クラスの散布図プロット形状 cv オプション int or sklearn.model _selection.* None クロスバリデーション分割法 (Noneのとき学習データから指標算出、int入力時はkFoldで分割)           cv_seed オプション int 42 クロスバリデーションの乱数シード cv_group オプション str None GroupKFold, LeaveOneGroupOutのグルーピング対象カラム名 display_cv_indices オプション int 0 表示対象のクロスバリデーション番号 model_params オプション dict None 分類モデルに渡すパラメータ fit_params オプション dict None 学習時のパラメータ subplot_kws オプション dict None matplotlib.pyplot.subplotsに渡す引数 contourf_kws オプション dict None proba_type='contour'のときmatplotlib.pyplot.contourf、 proba_type='contour'のときcontour)に渡す引数 scatter_kws オプション dict None 散布図用のmatplotlib.pyplot.scatterに渡す引数 plot_border オプション bool True 決定境界線の描画有無 proba_class オプション str or list[str] None 確率表示対象のクラス名 proba_cmap_dict オプション dict[str, str] None クラス確率図のカラーマップ(クラス名とcolormapをdict指定) proba_type オプション str 'contourf' クラス確率図の描画種類(等高線'contourf', 'contour', or RGB画像'imshow') imshow_kws オプション dict None proba_type='imshow'のときmatplotlib.pyplot.imshowに渡す引数 classplotクラス使用法詳細 こちらの記事にまとめました https://qiita.com/c60evaporator/items/43866a42e09daebb5cc0 regplotクラス 相関・回帰分析の散布図・ヒートマップ表示を実行します。 Scikit-Learn APIに対応した回帰モデル (例: XGBoostパッケージのXGBoostRegressorクラス)が表示対象となります ・regplotクラス内のメソッド一覧 メソッド名 機能 linear_plot ピアソン相関係数とP値を散布図と共に表示 regression_pred_true 予測値vs実測値プロット regression_plot_1d 1次元説明変数で回帰線表示 regression_heat_plot 2~4次元説明変数で回帰予測値をヒートマップ表示 linear_plotメソッド 実行例 from seaborn_analyzer import regplot import seaborn as sns iris = sns.load_dataset("iris") regplot.linear_plot(x='petal_length', y='sepal_length', data=iris) 引数一覧 引数名 必須引数orオプション 型 デフォルト値 内容 x 必須 str - 横軸に指定するカラム名 y 必須 str - 縦軸に指定するカラム名 data 必須 pd.DataFrame - 入力データ ax オプション matplotlib.axes.Axes None 表示対象のAxes (Noneならmatplotlib.pyplot.plotで1枚ごとにプロット) hue オプション str None 色分けに指定するカラム名 linecolor オプション str 'red' 回帰直線の色 rounddigit オプション int 5 表示指標の小数丸め桁数 plot_scores オプション bool True 回帰式、ピアソンの相関係数およびp値の表示有無 scatter_kws オプション dict None seaborn.scatterplotに渡す引数 regression_pred_trueメソッド 実行例 import pandas as pd from seaborn_analyzer import regplot import seaborn as sns from sklearn.linear_model import LinearRegression df_temp = pd.read_csv(f'./sample_data/temp_pressure.csv') regplot.regression_pred_true(LinearRegression(), x=['altitude', 'latitude'], y='temperature', data=df_temp) 引数一覧 引数名 必須引数orオプション 型 デフォルト値 内容 model 必須 Scikit-learn API - 表示対象の回帰モデル x 必須 list[str] - 説明変数に指定するカラム名のリスト y 必須 str - 目的変数に指定するカラム名 data 必須 pd.DataFrame - 入力データ hue オプション    str None 色分けに指定するカラム名 linecolor オプション str 'red' 予測値=実測値の線の色 rounddigit オプション int 3 表示指標の小数丸め桁数 rank_number オプション int None 誤差上位何番目までを文字表示するか rank_col オプション str None 誤差上位と一緒に表示するフィールド名 scores オプション str or list[str] 'mae' 文字表示する評価指標を指定 ('r2', 'mae', 'rmse', 'rmsle', or 'max_error') cv_stats オプション str 'mean' クロスバリデーション時に表示する評価指標統計値 ('mean', 'median', 'max', or 'min') cv オプション int or sklearn.model _selection.* None クロスバリデーション分割法 (Noneのとき学習データから指標算出、int入力時はkFoldで分割)           cv_seed オプション int 42 クロスバリデーションの乱数シード model_params オプション dict None 回帰モデルに渡すパラメータ fit_params オプション dict None 学習時のパラメータをdict指定 subplot_kws オプション dict None matplotlib.pyplot.subplotsに渡す引数 scatter_kws オプション dict None seaborn.scatterplotに渡す引数 regression_plot_1dメソッド 実行例 from seaborn_analyzer import regplot import seaborn as sns from sklearn.svm import SVR iris = sns.load_dataset("iris") regplot.regression_plot_1d(SVR(), x='petal_length', y='sepal_length', data=iris) 引数一覧 引数名 必須引数orオプション 型 デフォルト値 内容 model 必須 Scikit-learn API - 表示対象の回帰モデル x 必須 str - 説明変数に指定するカラム名 y 必須 str - 目的変数に指定するカラム名 data 必須 pd.DataFrame - 入力するデータ(Pandasのデータフレーム) hue オプション    str None 色分けに指定するカラム名 linecolor オプション str 'red' 予測値=実測値の線の色 rounddigit オプション int 3 表示指標の小数丸め桁数 rank_number オプション int None 誤差上位何番目までを文字表示するか rank_col オプション str None 誤差上位と一緒に表示するフィールド名 scores オプション str or list[str] 'mae' 文字表示する評価指標を指定 ('r2', 'mae', 'rmse', 'rmsle', or 'max_error') cv_stats オプション str 'mean' クロスバリデーション時に表示する評価指標統計値 ('mean', 'median', 'max', or 'min') cv オプション int or sklearn.model _selection.* None クロスバリデーション分割法 (Noneのとき学習データから指標算出、int入力時はkFoldで分割)           cv_seed オプション int 42 クロスバリデーションの乱数シード model_params オプション dict None 回帰モデルに渡すパラメータ fit_params オプション dict None 学習時のパラメータをdict指定 subplot_kws オプション dict None matplotlib.pyplot.subplotsに渡す引数 scatter_kws オプション dict None seaborn.scatterplotに渡す引数 regression_heat_plotメソッド 実行例 import pandas as pd from sklearn.linear_model import LinearRegression from seaborn_analyzer import regplot df_temp = pd.read_csv(f'./sample_data/temp_pressure.csv') regplot.regression_heat_plot(LinearRegression(), x=['altitude', 'latitude'], y='temperature', data=df_temp) 引数一覧 引数名 必須引数orオプション 型 デフォルト値 内容 model 必須 Scikit-learn API - 表示対象の回帰モデル x 必須 list[str] - 説明変数に指定するカラム名のリスト y 必須 str - 目的変数に指定するカラム名 data 必須 pd.DataFrame - 入力データ x_heat オプション    list[str] None 説明変数のうちヒートマップ表示対象のカラム名 scatter_hue オプション str None 散布図色分け指定カラム名 (plot_scatter='hue'時のみ有効) pair_sigmarange オプション float 1.5 ヒートマップ非使用変数の分割範囲 pair_sigmainterval オプション float 0.5 ヒートマップ非使用変数の1枚あたり表示範囲 heat_extendsigma オプション float 0.5 ヒートマップ縦軸横軸の表示拡張範囲 heat_division オプション int 30 ヒートマップ縦軸横軸の解像度 value_extendsigma オプション float 0.5 ヒートマップの色分け最大最小値拡張範囲 plot_scatter オプション str 'true' 散布図の描画種類 rounddigit_rank オプション int 3 誤差上位表示の小数丸め桁数 rounddigit_x1 オプション int 2 ヒートマップ横軸の小数丸め桁数 rounddigit_x2 オプション int 2 ヒートマップ縦軸の小数丸め桁数 rounddigit_x3 オプション int 2 ヒートマップ非使用軸の小数丸め桁数 rank_number オプション int None 誤差上位何番目までを文字表示するか rank_col オプション str None 誤差上位と一緒に表示するフィールド名 cv オプション int or sklearn.model _selection.* None クロスバリデーション分割法 (Noneのとき学習データから指標算出、int入力時はkFoldで分割)           cv_seed オプション int 42 クロスバリデーションの乱数シード display_cv_indices オプション int 0 表示対象のクロスバリデーション番号 model_params オプション dict None 回帰モデルに渡すパラメータ fit_params オプション dict None 学習時のパラメータをdict指定 subplot_kws オプション dict None matplotlib.pyplot.subplotsに渡す引数 heat_kws オプション dict None ヒートマップ用のseaborn.heatmapに渡す引数 scatter_kws オプション dict None 散布図用のmatplotlib.pyplot.scatterに渡す引数 regplotクラス使用法詳細 こちらの記事にまとめました https://qiita.com/c60evaporator/items/c930c822b527f62796ee おわりに データ分析をしていれば、これらの機能がいずれ使える場面があるかと思うので、ぜひ活用いただければと思います。 もしこのツールを良いと思われたら、GitHubにStar頂けるとありがたいです!
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

AtCoder Beginner Contest 212 参戦記

AtCoder Beginner Contest 212 参戦記 ABC212A - Alloy 2分半で突破. 書くだけ. A, B = map(int, input().split()) if A > 0 and B == 0: print('Gold') elif A == 0 and B > 0: print('Silver') elif A > 0 and B > 0: print('Alloy') ABC212B - Weak Password 4分で突破. 書くだけ. X = list(map(int, input())) if X.count(X[0]) == 4 or all(X[i + 1] == (X[i] + 1) % 10 for i in range(3)): print('Weak') else: print('Strong') ABC212C - Min Difference 12分半で突破. 単純に2重ループの O(NM) で駄目なら、にぶたんで O((N+M)logM) にすることを思いつく. Ai 以下の B のインデックスを見つけて、その前後だけ差分を取ればいい. from bisect import bisect_left N, M = map(int, input().split()) A = list(map(int, input().split())) B = list(map(int, input().split())) B.sort() result = 10 ** 9 for a in A: i = bisect_left(B, a) if i != 0: result = min(result, abs(a - B[i - 1])) if i != M: result = min(result, abs(a - B[i])) print(result) A も小さい順に並べてしまえば、にぶたんを使わなくても順次 Ai 以下の B のインデックスを見つけることができる. N, M = map(int, input().split()) A = list(map(int, input().split())) B = list(map(int, input().split())) B.sort() result = 10 ** 9 i = 0 for a in sorted(A): while i != M and B[i] < a: i += 1 if i != 0: result = min(result, abs(a - B[i - 1])) if i != M: result = min(result, abs(a - B[i])) print(result) ABC212D - Querying Multiset 6分で突破. 操作2をバカ正直にやると O(Q2) になって TLE. 袋の中の数字と袋の外の補正値を足すと本当の数値になるという持ち方にすると操作2が O(1) になる. 操作3をバカ正直にやるとやっぱり O(Q2) になって TLE するので、優先度付きキューか平衡二分探索木の出番です. Python には組み込みの平衡二分探索木がないので heapq になるいつもの奴. from sys import stdin from heapq import heappop, heappush readline = stdin.readline Q = int(readline()) t = 0 result = [] bag = [] for _ in range(Q): query = readline() if query[0] == '1': _, X = map(int, query.split()) heappush(bag, X - t) elif query[0] == '2': _, X = map(int, query.split()) t += X elif query[0] == '3': result.append(heappop(bag) + t) print(*result, sep='\n')
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

~ ダイクストラ法 ~ チートシート

目次 ダイクストラ法とは 実装 もしかして:ワーシャルフロイド法 はじめに チートシートの扱いついてはここを読んでください ダイクストラ法とは わかりやすいサイト 指定した頂点を始点とした場合に、最も到達にかかるコストが低い頂点から順に確定させていくことで、各頂点に到達するのに必要な最低コストを求めるアルゴリズム 始点となる頂点が1つに決まっている場合に有効 計算量が大きいので、始点が全頂点の場合はワーシャルフロイド法を使う 負のコストがある場合には使えないので注意 実装 問題(Atcoderだと簡単な問題がなかったので、AIZU ONLINE JUDGEより) Dijkstra's_algorithm.py import heapq N,M,S = map(int,input().split()) #頂点の数、辺の数、スタートとなる頂点の番号 root = [] #辺に関する情報を格納する配列 start = [[] for i in range(N+1)] #辺を始点ごとに分類する配列 for i in range(M): A,B,C = map(int,input().split()) #辺の始点、辺の終点、移動に必要なコスト root.append([A,B,C]) start[A].append(i) check = [] #最小コストの候補を格納する配列 heapq.heapify(check) heapq.heappush(check, [0,S]) ans = [-1]*(N+1) #各頂点の最小コストを格納する配列 while True: cost,now = heapq.heappop(check) if ans[now] == -1: ans[now] = cost for i in range(len(start[now])): heapq.heappush(check, [cost+root[start[now][i]][2], root[start[now][i]][1]]) if check == []: break for i in range(N): #for i in range(1,N+1): #頂点の番号が1,2,3...と振られている場合(Atcoder用) if ans[i] != -1: print(ans[i]) else: print("INF") Atcoderと異なりインデックスが0始まりなので注意(Pythonと同じだからこっちの方がありがたいけど) 高速化のために優先度付きキューを利用 ライブラリshortest_path()を使うという手もあるけど、Atcoderだとなんかめんどくさいらしい
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

【AtCoder解説】PythonでABC212のA,B,C,D,E問題を制する!

ABC212のA,B,C,D,E問題を、Python3でなるべく丁寧に解説していきます。 ただ解けるだけの方法ではなく、次の3つのポイントを満たす解法を解説することを目指しています。 シンプル:余計なことを考えずに済む 実装が楽:ミスやバグが減ってうれしい 時間がかからない:パフォが上がって、後の問題に残せる時間が増える ご質問・ご指摘はコメントかツイッター、その他のご意見・ご要望などはマシュマロまでお気軽にどうぞ! Twitter: u2dayo マシュマロ: https://marshmallow-qa.com/u2dayo ほしいものリスト : プレゼントしていただけると、やる気が出ます! よかったらLGTMや拡散していただけると喜びます! 目次 ABC212 まとめ A問題『Alloy』 B問題『Weak Password』 C問題『Min Difference』 D問題『Querying Multiset 』 E問題『Safety Journey』 アプリ AtCoderFacts を開発しています コンテストの統計データを見られるアプリ『AtCoderFacts』を作りました。 現在のところ、次の3つのデータを見ることができます。 - レート別問題正解率 - パフォーマンス目安 - 早解きで上昇するパフォーマンス 今後も機能を追加していく予定です。使ってくれると喜びます。 ABC212 まとめ 全提出人数: 7893人 パフォーマンス パフォ AC 点数 時間 順位(Rated内) 200 AB------ 300 27分 5662(5431)位 400 ABC----- 600 97分 4616(4386)位 600 ABC----- 600 42分 3785(3557)位 800 ABC----- 600 13分 2950(2725)位 1000 ABCD---- 1000 67分 2192(1969)位 1200 ABCD---- 1000 36分 1573(1355)位 1400 ABCDE--- 1500 125分 1101(891)位 1600 ABCDE--- 1500 63分 759(555)位 1800 ABCDE--- 1500 40分 511(321)位 2000 ABCDE--- 1500 21分 340(171)位 2200 ABCDE-G- 2100 105分 212(84)位 2400 ABCDE-G- 2100 76分 133(40)位 色別の正解率 色 人数 A B C D E F G H 灰 3611 98.4 % 71.8 % 36.7 % 7.8 % 0.7 % 0.0 % 0.1 % 0.0 % 茶 1345 99.5 % 98.1 % 85.6 % 40.1 % 2.7 % 0.1 % 0.3 % 0.1 % 緑 1046 99.6 % 98.8 % 97.2 % 73.4 % 15.7 % 0.4 % 0.4 % 0.0 % 水 635 99.5 % 99.2 % 98.9 % 90.7 % 56.1 % 0.9 % 3.6 % 0.0 % 青 372 100.0 % 99.7 % 100.0 % 97.6 % 85.5 % 16.1 % 19.4 % 2.1 % 黄 176 94.3 % 93.8 % 93.8 % 94.3 % 88.1 % 25.6 % 53.4 % 7.4 % 橙 37 97.3 % 97.3 % 97.3 % 97.3 % 97.3 % 48.6 % 73.0 % 32.4 % 赤 19 89.5 % 89.5 % 89.5 % 94.7 % 89.5 % 89.5 % 100.0 % 63.2 % ※表示レート、灰に初参加者は含めず A問題『Alloy』 問題ページ:A - Alloy 灰コーダー正解率:98.4 % 茶コーダー正解率:99.5 % 緑コーダー正解率:99.6 % 考察 問題文どおりにif文を書けばいいです。条件は以下の3つです。 $0 < A$ かつ $ B = 0$ なら Gold $A = 0$ かつ $0 < B$ なら Silver $0 < A $ かつ $0 < B $ なら Alloy $0\lt{A+B}$​ ですから、$A=0$​ かつ $B = 0$​​ はありません。(無から無を生成するのは意味がわかりませんね) よく考えると、$A,B$ それぞれが $0$ かどうかだけ判定すれば良いので、以下のように判定を簡略化できます。 $B=0$ なら Gold $A=0$ なら Silver​ 上 $2$ つのどちらでもないなら Alloy コード def main(): A, B = map(int, input().split()) if B == 0: print("Gold") elif A == 0: print("Silver") else: print("Alloy") if __name__ == '__main__': main() B問題『Weak Password』 問題ページ:B - Weak Password 灰コーダー正解率:71.8 % 茶コーダー正解率:98.1 % 緑コーダー正解率:98.8 % 考察 問題文の意味がわかりづらいので、$20$ 種類しかない弱いパスワードを全列挙してみます。 全部同じ数字: 0000,1111,2222,3333,4444,5555,6666,7777,8888,9999 数字が4つ連続している: 0123,1234,2345,3456,4567,5678,6789,7890,8901,9012 これらを直接打ち込んで比較してもACできますが、for文などでこれらの文字列を作ったほうが楽です。 コード 20種類の弱いパスワードを書いて比較するコード 打ち間違えの恐れがあるので、おすすめはしません。 def main(): X = input() P = ['0000', '1111', '2222', '3333', '4444', '5555', '6666', '7777', '8888', '9999', '0123', '1234', '2345', '3456', '4567', '5678', '6789', '7890', '8901', '9012'] print('Weak' if X in P else 'Strong') if __name__ == '__main__': main() 20種類の弱いパスワードをfor文で生成して比較するコード def main(): X = input() P = [str(i) * 4 for i in range(10)] # とりあえず0000~9999を作る # ここで0123~9012を作る for i in range(10): s = "" for j in range(4): s += str((i + j) % 10) # 9の次は0なので、i + j = 10のとき0にしたい P.append(s) print('Weak' if X in P else 'Strong') if __name__ == '__main__': main() 真面目に判定するコード def main(): def solve(): if len(set(X)) == 1: return 'Weak' # Xの文字種が1種類のみです for i in range(3): if (int(X[i]) + 1) % 10 != int(X[i + 1]): return 'Strong' # 1箇所でも連続していなければ、その時点で強いパスワードです return 'Weak' # 1234のような4連続の数字なので、弱いパスワードです X = input() print(solve()) if __name__ == '__main__': main() C問題『Min Difference』 問題ページ:C - Min Difference 灰コーダー正解率:36.7 % 茶コーダー正解率:85.6 % 緑コーダー正解率:97.2 % 考察 もちろん、すべての $A_i$​ と $B_j$ の組を試すとTLEになります。 ある $A_i$​​ を $1$​ つに対して、数列 $B$​​​ の中で『差の絶対値の最小』になる相手を高速に探すことを考えます。これができれば、数列$A$ の $N$ 個ある全要素に対して同様に 『差の絶対値の最小』 を求めることで、この問題を解くことができます。 そのために、二分探索を使います。 二分探索を使う 二分探索は、ソートされた数列 $L$ にある値 $x$ を挿入するとき、$x$ が何番目に入るかを $O(log\, N)$​​​で求めることができます。(二分探索の詳しい説明はググってください) 下図の数列$B$ の $5$ と $9$ の間に $A_i$ が入るとわかったとき、『差の絶対値の最小』の候補は、すぐ左の $|A_i - 5|$ か、すぐ右の $|9-A_i|$​です。これらより遠くにある$|A_i-2|$ や $|13-A_i|$ は最小値にはなり得ません。そのため、両隣 $2$ 箇所だけ計算してみて、小さいほうが答えです。 まとめ さて、この問題を解くアルゴリズムは以下の通りです。 数列 $B$​​ をソートする forループを使って、すべての $A_i$​ で $B$​ を二分探索して、すべての $A_i$​ に対する最適解を求める そのうち最小のものが答え ただし、$B$ の最大値よりも $A_i$ が大きい場合と、$B$ の最小値よりも $A_i$ が小さい場合に、配列外参照を起こさないように注意してください。(if文でチェックします) 二分探索 $1$​​ 回の計算量は$O(log\, M)$​​ なので、長さ$N$ の 数列$A$ の要素すべてについて行うと $O(N log\,M)$ です。また、$B$ のソートの計算量は$O(Mlog\,M)$ です。したがって、全体の計算量は$O((N+M)log\,M)$​​です。 コード def main(): import bisect N, M = map(int, input().split()) A = list(map(int, input().split())) B = list(map(int, input().split())) B.sort() # 二分探索するので、Bはソートしておきます INF = float('INF') ans = INF # 正の無限大で初期化します for a in A: i = bisect.bisect_left(B, a) # bisect_rightでもいいです # 配列外やB[-1]を参照するのを防ぐために、if文を使います if 0 <= i - 1 < M: b1 = B[i - 1] ans = min(ans, abs(a - b1)) if 0 <= i < M: b2 = B[i] ans = min(ans, abs(a - b2)) print(ans) if __name__ == '__main__': main() D問題『Querying Multiset 』 問題ページ:D - Querying Multiset 灰コーダー正解率:7.8 % 茶コーダー正解率:40.1 % 緑コーダー正解率:73.4 % 考察 $3$ つの操作を簡単に書き直します。 操作 $1$ : $X_i$ が書かれたボールを袋に入れる 操作 $2$ : 袋に入っているすべてのボールの数字に $X_i$ を足す 操作 $3$​​ : 袋から『数字が一番小さいボール』を取り出して、書かれている数字を出力する 操作2さえなければ簡単だが 袋全体のボールの数字を書き換える操作 $2$ さえなければ、優先度付きキュー(Priority Queue, Pythonではheapqモジュール)を使うだけで簡単に解けます。 優先度付きキューは、次の操作を以下の計算量で行えるデータ構造です。要素の追加や削除があるときに、ソートの代わりに使うものだと思ってください。 最小の要素の取得:$O(1)$ 最小の要素の削除:$O(log\,N)$ 要素の追加:$O(log\,N)$​ もし操作 $2$ がなければ、操作 $1$ が来たら要素を追加、操作 $3$ が来たら要素を取得・出力・削除をするだけで解けてしまいます。 操作2 の処理を考える しかし、操作 $2$​ が来るたびに優先度付きキュー内の数字をすべて書き換えるとTLEになります。($10$ 万個のボールの数字を $10$ 万回書き換えるケースができます)そもそも、優先度付きキューは通常の実装では要素の書き換えをすることができません。 そこで、操作 $2$​​ が来たとき、優先度キュー内の要素は一切変更せず、代わりに全体に足された値を変数 $S$​​​ で管理することにします。そして、以下のように操作をすれば、ボールの数字そのものを書き換えることなく、ボールと書かれた数字を優先度付きキューで管理することができます。 はじめ は 袋(優先度付きキュー)は空、$S=0$ である 操作 $2$ では $S$ に $X_i$ を足す($S \leftarrow{S+X_i}$​) 操作 $3$ で袋から最小のボール $X_{min}$ を取り出したあと、$X_{min}+S$​ を出力する($S$ が袋全体に足されているからです) 操作 $1$ で袋にボール $X_i$ を追加するとき、代わりに $X_i - S$ を追加する(このまま $X_i$ を追加すると、袋に$X_i + S$ のボールが追加されてしまいます。そのため、$S$ を引いて $(X_i - S) + S = X_i$ となるように調整してあげます) 計算量は $O(Qlog\, Q)$です。 コード # python標準ライブラリのheapqは大変使いづらいので、使いやすいようクラスにしておくと良いです # なお、このクラスにバグがあっても責任は一切取りません、自分で書いてね class PriorityQueue: def __init__(self, a=None): import heapq self.heapq = heapq self.__container = [] if a: self.__container = a[:] self.heapq.heapify(self.__container) @property def is_empty(self): return not self.__container def pop(self): return self.heapq.heappop(self.__container) def push(self, x): self.heapq.heappush(self.__container, x) def sum(self): return sum(self.__container) def __len__(self): return len(self.__container) def __str__(self): return str(sorted(self.__container)) def __repr__(self): return self.__str__() def main(): Q = int(input()) pq = PriorityQueue() S = 0 for _ in range(Q): q = list(map(int, input().split())) query = q[0] if query == 1: x = q[1] pq.push(x - S) # x + S - S = x です elif query == 2: x = q[1] S += x # 袋全体にx足します elif query == 3: y = pq.pop() print(y + S) # もともとyが書かれたボールに、袋全体のSを足します if __name__ == '__main__': main() E問題『Safety Journey』 問題ページ:E - Safety Journey 灰コーダー正解率:0.7 % 茶コーダー正解率:2.7 % 緑コーダー正解率:15.7 % 考察 無向グラフが与えられます。都市 $1$ からはじめて、$K$ 回移動した後、最後に都市 $1$ に戻るルートの組み合わせ数を求める問題です。$K$ は最大で $5000$ です。 都市の数 $N$​​ は最大 $5000$​​​ です。どの $2$​ つの相違なる都市の間も双方向に通れるとあります。$N\times(N-1)$​ 通りの都市の組があるため、辺の数は最大でおよそ $2500$​​​ 万本になります。 ただし、$N\times{(N-1)}$​​ 本の辺のうち、$M$​​ 本の辺が使えないとあります。$M$​​ は最大で $5000$​​​​ しかありません。 ほとんどの道は使える→使えない道だけ引けばいい この問題は動的計画法で解けます。しかし、$5000$ 回 の 移動のたびに $2500$ 万本の辺すべてを見て遷移をしていては、当然間に合いません。 ここで、ほとんどの道は使うことができて、使えない道は非常に少ないことに着目します。つまり、一旦全ての道を使えるものとして計算したあと、使えない道の分を引けば、高速で解くことができます。 具体的には、以下の操作を$K$​ 回繰り返すことで計算量 $O(K(N+M))$​​ でこの問題を解くことができます。 一旦使えない道のことは忘れて、全都市から移動できることにする(そのために、全都市の組み合わせ数の操作をsum関数で求めておく) 同じ都市から同じ都市に直接移動はできないので、その分を引く 使えない $M$​ 個 の 辺それぞれについて、組み合わせ数を引く 実装 $10^9 + 7$​ ではなく $998244353$​​ で割った余りを求めることに気をつけましょう。 コード Pythonで通すのは大変厳しい(numpyかnumbaを使わないとおそらく無理です)ので、PyPyで提出してください。 MOD = 998244353 # 998244353の余りを答えることに気をつけてください。タイプミスを防ぐために、問題文からコピペしましょう。 def main(): N, M, K = map(int, input().split()) edge = [] for _ in range(M): u, v = map(int, input().split()) edge.append((u, v)) dp_prev = [0] * (N + 1) dp_prev[1] = 1 for _ in range(K): dp_next = [0] * (N + 1) S = sum(dp_prev) % MOD # 前の日の、全都市の組み合わせ数の合計です for i in range(1, N + 1): dp_next[i] = (S - dp_prev[i]) % MOD # i以外の都市からiに向います。都市iから都市iには行けないので引きます for u, v in edge: # u -> v と v -> u が封鎖されていて使えないので、その分引きます dp_next[u] -= dp_prev[v] dp_next[v] -= dp_prev[u] dp_next[u] %= MOD dp_next[v] %= MOD dp_prev = dp_next print(dp_prev[1]) if __name__ == '__main__': main()
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Progateで作ったWebアプリをDjangoで作ってみる2! Part6 -base.html編-

目標物の確認 ProgateのNode.jsコースで作ったブログアプリと同じものをDjangoで作ってみます。 Djangoでのアプリ開発の一連の流れを整理するために記していきます。 完成イメージ レイアウトの使いまわし どのHTMLファイルにも同じレイアウト部分がある場合、その部分だけ使い回すことができます。 そうすることで、共通部分のレイアウトに変更が生じた場合、大元のHTMLファイルだけ変更することで、個々のページを変更する必要がなくなります。そっちのほうが効率的です。 今回は下の赤枠部分がいくつかのページで使い回されています。 Djangoではテンプレートの継承(Template inheritance)と呼ばれています。 Template inheritance 定形のhtmlファイルを作成して、それを個々のHTMLファイルで読み込みます。 ベースとなるHTMLファイル、今回は赤枠の部分をbase.htmlとして作ります。 base.htmlファイルが全体のフレームワークとなり、個々のページの中(このあと作っていくlist.html、article.html)で{% block content %}の情報を入れていきます(☆1)。 base.htmlの中を見てみます。ログインしているときは、「ようこそ、ユーザー名さん」が表示されるようになり、ログインしていないときは「ようこそ、ゲストさん」と表示されるようにしました(☆2)。 さらに、ログインしているときは、「ログアウト」が、ログインしていないときは「新規登録」と「ログイン」が表示されるようにしました(☆3)。 base.html {% load static %} <!DOCTYPE html> <html> <head> <meta charset="utf-8"> <title>BLOG</title> <link rel="stylesheet" href="{% static 'blog/style.css' %}"> </head> <body> <header> <div class="header-nav"> <a href="/">BLOG</a> {% if user.is_authenticated %} # ☆2 <p>ようこそ、{{ user.username }}さん</p> # ☆2 {% else %} # ☆2 <p>ようこそ、ゲストさん</p> {% endif %} # ☆2 <ul> <li><a href="{% url 'blog:list' %}">記事一覧</a></li> {% if user.is_authenticated %} # ☆3 <li><a href="{% url 'blog:logout' %}">ログアウト</a></li> {% else %} # ☆3 <li><a href="{% url 'blog:signup' %}">新規登録</a></li> <li><a href="{% url 'blog:login' %}">ログイン</a></li> {% endif %} # ☆3 </ul> </div> <p>わんこの学びブログ</p> </header> </body> {% block content %} <!--☆1--> {% endblock content %} <!--☆1--> 個別のlist.htmlファイルを見てみます。 {% extends 'blog/base.html' %}の部分がポイントです。このコードはbase.htmlに記載されている内容をベースとして広げて使っていくイメージです(☆4)。base.htmlの中で定義した{% block content %} {% endblock content %}の中に個々のページのコードを書いていく感じです(☆5)。 ブログ記事のカテゴリーが全員か会員限定かで表示が変わるようにします(☆6)。 list.html {% extends 'blog/base.html' %} <!--☆4--> {% block content %} <!--☆5--> <main> <ul class="list"> {% for object in object_list %} <li> {% if object.category == 'limited' %} <!--☆6--> <i>会員限定</i> {% endif %} <!--☆6--> <h2>{{ object.title }}</h2> <p>{{ object.summary }}</p> <a href="{% url 'blog:article' object.pk %}">続きを読む</a> </li> {% endfor %} <!--☆6--> </ul> </main> {% endblock content %} <!--☆5--> ブログ記事一覧を表示するため、ListViewを使います(☆7)。 blogapp/blog/views.py from blog.models import BlogModel from django.shortcuts import render, redirect from django.views.generic import TemplateView, ListView # ☆7 from django.contrib.auth.models import User from django.db import IntegrityError from django.contrib.auth import authenticate, login, logout from .models import BlogModel class BlogTop(TemplateView):... def signupview(request):... def logoutview(request):... class BlogList(ListView): # ☆7 template_name = 'blog/list.html' model = BlogModel urls.pyを編集します(☆8)。 blogapp/blog/urls.py from django.urls import path from .views import BlogTop, signupview, loginview, logoutview, BlogList # ☆8 app_name = 'blog' urlpatterns = [ path('', BlogTop.as_view(), name='top'), path('signup/', signupview, name='signup'), path('login/', loginview, name='login'), path('logout/', logoutview, name='logout'), path('list/', BlogList.as_view(), name='list'), # ☆8 ] 完成。
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む