Note: Students should always aim to produce publication-worthy tables and figures. Unless otherwise stated, tables should be rendered using stargazer::stargazer(), while figures can be rendered using ggplot2::ggplot() or plot(). Regardless, tables and figures should always be presented with the necessary formatting (e.g., a (sub)title, axis/variable labels and titles, a clearly-identifiable legend and key, etc.). Problem sets must always be compiled using LaTeX or RMarkdown and include the full coding routine (with notes explaining your implementation) used to complete each problem (10pts).
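As a quick illustration of that standard (a hedged sketch only, using R's built-in mtcars data rather than any assignment dataset), a figure meeting these requirements might look like:

library(ggplot2) # Figure Rendering
ggplot(mtcars, aes(x = wt, y = mpg, color = factor(cyl))) + # Example Data
  geom_point() +
  labs(title = "Fuel Efficiency by Vehicle Weight", # Title
       subtitle = "Example: mtcars", # Subtitle
       x = "Weight (1,000 lbs)", # X-Axis Label
       y = "Miles per Gallon", # Y-Axis Label
       color = "Cylinders") # Legend Title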
Using the ag_news dataset from the textdata package in R (textdata::dataset_ag_news()), construct a Continuous Bag of Words (CBOW) model. With that model, use predict() to recover the 5 nearest words for a series of 5 other words that you choose.

library(word2vec) # word2vec Implementation (CBOW)
ag <- textdata::dataset_ag_news() # Load AG News
ag$title_clean <- reduce_complexity(ag$title) # Clean Titles (Custom Helper)
cbow_model <- word2vec(x = ag$title, # AG News Titles
                       type = "cbow", # Cont. Bag of Words
                       dim = 15, # 15 Dimensions
                       iter = 20) # 20 Iterations (max)
cbow_lookslike <- predict(cbow_model, # Trained CBOW Model
                          c("soccer", "champion", "war", "mars", "river"), # Query Words
                          type = "nearest", top_n = 5) # 5 Nearest Neighbors Each
print(cbow_lookslike)
## $soccer
## term1 term2 similarity rank
## 1 soccer bronze 0.9598973 1
## 2 soccer final 0.9452686 2
## 3 soccer qualifying 0.9443066 3
## 4 soccer tennis 0.9433466 4
## 5 soccer qualifier 0.9430767 5
##
## $champion
## term1 term2 similarity rank
## 1 champion defending 0.9639828 1
## 2 champion champ 0.9475932 2
## 3 champion Noguchi 0.9278772 3
## 4 champion cycling 0.9270868 4
## 5 champion guessing 0.9242694 5
##
## $war
## term1 term2 similarity rank
## 1 war elections 0.9542273 1
## 2 war action 0.9511375 2
## 3 war ban 0.9296306 3
## 4 war state 0.9282848 4
## 5 war warned 0.9273162 5
##
## $mars
## term1 term2 similarity rank
## 1 mars resort 0.8856683 1
## 2 mars Bangladeshi 0.8798008 2
## 3 mars Caucasus 0.8784330 3
## 4 mars condemns 0.8760334 4
## 5 mars reporter 0.8669539 5
##
## $river
## term1 term2 similarity rank
## 1 river quakes 0.9323485 1
## 2 river train 0.9317452 2
## 3 river earthquake 0.9278418 3
## 4 river landslides 0.9272598 4
## 5 river trains 0.9270689 5
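If the vectors themselves are needed rather than nearest neighbors (e.g., for custom similarity calculations), the word2vec package can also return the embeddings directly. A minimal sketch, assuming the cbow_model object fit above:

cbow_embeddings <- as.matrix(cbow_model) # Vocabulary x 15 Embedding Matrix
dim(cbow_embeddings) # Rows = Vocabulary Terms, Columns = Dimensions
predict(cbow_model, c("soccer", "war"), type = "embedding") # Vectors for Specific Words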
Using SBERT and the ag_news dataset from the textdata package in R (textdata::dataset_ag_news()), recover the embeddings from a sample of 1,000 titles. Afterwards, recover pairwise cosine similarity scores and report the top- and bottom-5 most similar pairs using stargazer (5pts).

set.seed(1234) # Reproducible Sample
library(dplyr) # Data Wrangling (Pipes, Joins)
ag <- textdata::dataset_ag_news() %>%
  dplyr::sample_n(100) %>% # Sample 100 Titles
  mutate(row_id = row_number()) # Row Index for Joins
sentences_pairwise <- expand.grid(row_i = ag$row_id,
                                  row_j = ag$row_id) %>% # All Ordered Pairs
  filter(row_i != row_j) %>% # Remove Self-Pairs
  left_join(ag, by = c("row_i" = "row_id")) %>%
  rename(text_i = title) %>%
  left_join(ag, by = c("row_j" = "row_id")) %>%
  rename(text_j = title) %>%
  rename_with(~ gsub("\\.x$", "_i", .x)) %>% # Suffix i-Side Columns
  rename_with(~ gsub("\\.y$", "_j", .x)) # Suffix j-Side Columns
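One aside on the pairwise grid: with self-pairs removed, expand.grid() leaves 100 × 99 = 9,900 ordered pairs, so each unordered pair appears twice (once as (i, j), once as (j, i)). If each pair is only needed once, a single additional filter would halve the later computation (the loop below keeps all 9,900 pairs, which is what its progress messages reflect):

sentences_pairwise <- sentences_pairwise %>%
  filter(row_i < row_j) # Keep Each Unordered Pair Once (4,950 Rows)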
library(reticulate) # Activate Reticulate
## Warning: package 'reticulate' was built under R version 4.4.3
virtualenv_create("r-reticulate") # Create Virtual Environment (If Needed)
## virtualenv: ~/.virtualenvs/r-reticulate
use_virtualenv("r-reticulate", required = TRUE) # Activate Environment
required_packages <- c("torch", "transformers", "sentence-transformers") # Pip Package Names
for (pkg in required_packages) {
  module <- gsub("-", "_", pkg) # Python Module Names Use Underscores
  if (!py_module_available(module)) {
    virtualenv_install(packages = pkg) # Install Missing Packages via pip
  }
} # Install torch, transformers, and sentence-transformers (if needed)
## Using virtual environment "~/.virtualenvs/r-reticulate" ...
## + "C:/Users/jaketruscott/Documents/.virtualenvs/r-reticulate/Scripts/python.exe" -m pip install --upgrade --no-user sentence-transformers
sentence_transformers <- import("sentence_transformers") # Import sentence-transformers
util <- import("sentence_transformers.util") # Cosine-Similarity Utilities
model_name <- "sentence-transformers/all-MiniLM-L6-v2" # SBERT
model <- sentence_transformers$SentenceTransformer(model_name) # Declare Model (SBERT)
embeddings <- model$encode(ag$title) # Convert Titles to Embeddings (One Row per Title)
cosine_sim <- data.frame() # Output DF
for (i in 1:nrow(sentences_pairwise)){
  temp_row <- sentences_pairwise[i,] # Temp Row
  i_embeddings <- embeddings[temp_row$row_i,] # i Embedding
  j_embeddings <- embeddings[temp_row$row_j,] # j Embedding
  temp_similarity <- util$cos_sim(i_embeddings, j_embeddings)$item() # Cosine Sim
  cosine_sim <- bind_rows(cosine_sim,
                          data.frame(temp_row, similarity = temp_similarity)) # Export
  if (i %% 500 == 0){
    message('Completed ', i, '/', nrow(sentences_pairwise)) # Progress
  }
}
## Completed 500/9900
## Completed 1000/9900
## Completed 1500/9900
## Completed 2000/9900
## Completed 2500/9900
## Completed 3000/9900
## Completed 3500/9900
## Completed 4000/9900
## Completed 4500/9900
## Completed 5000/9900
## Completed 5500/9900
## Completed 6000/9900
## Completed 6500/9900
## Completed 7000/9900
## Completed 7500/9900
## Completed 8000/9900
## Completed 8500/9900
## Completed 9000/9900
## Completed 9500/9900
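A hedged side note on speed: the row-by-row loop is transparent but slow, since it crosses the R/Python boundary 9,900 times. util$cos_sim() also accepts whole matrices, so the full similarity matrix can be computed in one call and then indexed; a sketch, assuming the embeddings and sentences_pairwise objects from above:

sim_matrix <- util$cos_sim(embeddings, embeddings)$numpy() # 100 x 100 Cosine Similarity Matrix
sentences_pairwise$similarity <- sim_matrix[cbind(sentences_pairwise$row_i,
                                                  sentences_pairwise$row_j)] # Look Up Each (i, j) Pair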
top_5 <- cosine_sim %>%
  arrange(desc(similarity)) %>% # Highest Similarity First
  slice_head(n = 5) # 5 Most Sim
top_5$text_i[1]
## [1] "Crude futures rally over \\$53"
top_5$text_j[1]
## [1] "Crude prices make a retreat"
top_5$similarity[1]
## [1] 0.6020842
bottom_5 <- cosine_sim %>%
  arrange(similarity) %>% # Lowest Similarity First
  slice_head(n = 5) %>% # 5 Least Sim
  arrange(desc(similarity)) # Re-Sort, Descending, for Presentation
bottom_5$text_i[1]
## [1] "State of Emergency Declared in Iraq"
bottom_5$text_j[1]
## [1] "Slower Pub Sales Hit Whitbread"
bottom_5$similarity[1]
## [1] -0.1957191
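The prompt asks for the results in a stargazer table; a minimal sketch of that final step, assuming the top_5 and bottom_5 data frames built above:

library(stargazer) # Table Rendering
report <- bind_rows(top_5, bottom_5) %>%
  select(text_i, text_j, similarity) # Title Pair + Similarity Score
stargazer(report, summary = FALSE, rownames = FALSE,
          title = "Top- and Bottom-5 Most Similar Title Pairs (SBERT)")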